
Commit 0f66258

update doc string
1 parent 466e469 commit 0f66258

File tree

5 files changed: +43 −17 lines changed


examples/ppo_countdown/countdown.yaml

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ algorithm:
 model:
   model_path: '/PATH/TO/MODEL/CHECKPOINT/'
   max_prompt_tokens: 256
-  max_repsonse_tokens: 1024
+  max_response_tokens: 1024
 cluster:
   node_num: 1
   gpu_per_node: 8
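
The fix matters because a misspelled key is typically ignored rather than rejected, so the intended 1024-token cap never takes effect. A minimal sketch of that failure mode (ModelCfg and its defaults are hypothetical stand-ins, not Trinity-RFT's actual config loader):

from dataclasses import dataclass, fields

@dataclass
class ModelCfg:  # hypothetical stand-in for trinity's ModelConfig
    max_prompt_tokens: int = 256
    max_response_tokens: int = 512  # assumed default, for illustration only

raw = {"max_prompt_tokens": 256, "max_repsonse_tokens": 1024}  # note the typo
known = {f.name for f in fields(ModelCfg)}
cfg = ModelCfg(**{k: v for k, v in raw.items() if k in known})
print(cfg.max_response_tokens)  # 512 -- the intended 1024 was silently dropped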

tests/common/config_test.py

Lines changed: 10 additions & 2 deletions

@@ -4,17 +4,25 @@
 import unittest

 from tests.tools import get_template_config
-from trinity.common.config import load_config
+from trinity.common.config import InferenceModelConfig, load_config


 class TestConfig(unittest.TestCase):
     def test_load_default_config(self):
         config = get_template_config()
+        config.buffer.batch_size = 8
         config.algorithm.repeat_times = 10
         config.model.model_path = "Qwen/Qwen3-1.7B"
+        config.cluster.gpu_per_node = 8
+        config.cluster.node_num = 2
+        config.explorer.rollout_model.engine_num = 2
+        config.explorer.rollout_model.tensor_parallel_size = 2
+        config.explorer.auxiliary_models.append(
+            InferenceModelConfig(model_path="Qwen/Qwen3-32B", tensor_parallel_size=4, engine_num=1),
+        )
         config.check_and_update()
         self.assertIsNotNone(config.trainer.trainer_config)
-        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2)
+        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 8)
         self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1)
         self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.project)
         self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.name)
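
The updated assertions can be read off from the GPU accounting the test sets up, assuming check_and_update hands the trainer whatever GPUs the explorer's inference engines do not occupy (an inference from the asserted values, not a documented contract):

# GPU budget implied by the test's cluster settings.
node_num, gpu_per_node = 2, 8
total_gpus = node_num * gpu_per_node            # 16

# Each engine occupies engine_num * tensor_parallel_size GPUs.
rollout_gpus = 2 * 2                            # rollout_model: 4
auxiliary_gpus = 1 * 4                          # Qwen3-32B auxiliary model: 4

trainer_gpus = total_gpus - rollout_gpus - auxiliary_gpus  # 8
assert trainer_gpus == 8                        # -> n_gpus_per_node == 8
assert trainer_gpus <= gpu_per_node             # fits on one node -> nnodes == 1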

trinity/cli/launcher.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 import os
 import sys
 from pathlib import Path
+from pprint import pprint

 import ray

@@ -157,6 +158,7 @@ def activate_data_module(data_workflow_url: str, config_path: str):
 def run(config_path: str, dlc: bool = False):
     config = load_config(config_path)
     config.check_and_update()
+    pprint(config)
     # try to activate data module
     data_processor_config = config.data_processor
     if data_processor_config.data_workflow_url and (
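
Printing the config right after check_and_update() surfaces all the auto-derived fields before anything launches. A standalone illustration with a hypothetical miniature of the config (on Python 3.10+, pprint renders dataclasses field by field rather than as one long repr line):

from dataclasses import dataclass, field
from pprint import pprint

@dataclass
class ClusterCfg:  # hypothetical miniature of a ClusterConfig
    node_num: int = 1
    gpu_per_node: int = 8

@dataclass
class Cfg:  # hypothetical miniature of the full Config
    project: str = "Trinity-RFT"
    name: str = "rft"
    cluster: ClusterCfg = field(default_factory=ClusterCfg)

pprint(Cfg(), width=40)  # each field lands on its own line (Python 3.10+)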

trinity/common/config.py

Lines changed: 30 additions & 13 deletions

@@ -53,12 +53,15 @@ class FormatConfig:

 @dataclass
 class GenerationConfig:
-    # repeat each task for `n` times (for GPRO-like algorithms)
-    n: int = 1
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1
     logprobs: int = 0  # vLLM return `logprobs + 1` elements
+    # repeat each task for `n` times (for GPRO-like algorithms)
+    # this field will be automatically set to `algorithm.repeat_times` in
+    # `buffer.explorer_input.taskset.rollout_args`
+    # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
+    n: int = 1


 @dataclass
@@ -67,7 +70,6 @@ class StorageConfig:

     name: str = ""
     storage_type: StorageType = StorageType.FILE
-    algorithm_type: Optional[AlgorithmType] = None  # automatically set
     path: Optional[str] = None

     # used for StorageType.FILE
@@ -76,13 +78,20 @@
     format: FormatConfig = field(default_factory=FormatConfig)
     index: int = 0

-    # used for algorithm_type is None
-    task_type: TaskType = TaskType.EXPLORE  # automatically set
+    # used for rollout tasks
     default_workflow_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
-    total_epochs: int = 1  # automatically set
     rollout_args: GenerationConfig = field(default_factory=GenerationConfig)

+    # ! DO NOT SET, automatically set from algorithm.algorithm_type
+    algorithm_type: Optional[AlgorithmType] = None
+
+    # ! DO NOT SET, automatically set from buffer.total_epochs
+    total_epochs: int = 1  # automatically set
+
+    # ! DO NOT SET, automatically set corresponding to train/eval
+    task_type: TaskType = TaskType.EXPLORE
+

 @dataclass
 class DataProcessorConfig:
@@ -124,8 +133,10 @@ class ModelConfig:

 @dataclass
 class InferenceModelConfig:
-    # For Rollout Model: automatically set from config.model.model_path
+    # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
     model_path: str = ""
+
+    # support `vllm` or `vllm_async`,
     engine_type: str = "vllm_async"
     engine_num: int = 1
     tensor_parallel_size: int = 1
@@ -136,15 +147,22 @@ class InferenceModelConfig:
     gpu_memory_utilization: float = 0.9
     dtype: str = "bfloat16"
     seed: int = 42
+
+    # if not set, use `model.max_prompt_tokens`
     max_prompt_tokens: Optional[int] = None
+    # if not set, use `model.max_response_tokens`
     max_response_tokens: Optional[int] = None
+
     # override chat template in model
     chat_template: Optional[str] = None
+
     # For Qwen3
     enable_thinking: bool = False
+
     # For OpenAI API
     enable_openai_api: bool = False
-    # DO NOT SET this field
+
+    # ! DO NOT SET
     bundle_indices: str = ""

@@ -209,7 +227,7 @@ class BufferConfig:
     max_retry_times: int = 3
     max_retry_interval: int = 1

-    # for experience construct, DO NOT SET
+    # ! DO NOT SET FOLLOWING FIELDS
     read_batch_size: int = 1  # automatically set
     tokenizer_path: Optional[str] = None  # automatically set
     pad_token_id: Optional[int] = None  # automatically set
@@ -260,8 +278,7 @@
 class MonitorConfig:
     # TODO: support multiple monitors (List[MonitorType])
     monitor_type: MonitorType = MonitorType.WANDB
-    # ! DO NOT SET
-    # the root directory for monitor cache and meta files, automatically generated
+    # ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
     cache_dir: str = ""

@@ -288,12 +305,12 @@ class Config:
     mode: str = "both"  # `explore`, `train`, `both` or `bench`
     project: str = "Trinity-RFT"
     name: str = "rft"
-    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
     # the root dir for checkpoints
     checkpoint_root_dir: str = ""
-    # DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
+    # ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
     checkpoint_job_dir: str = ""

+    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
     data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     cluster: ClusterConfig = field(default_factory=ClusterConfig)
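
All of the new `! DO NOT SET` markers describe fields that check_and_update fills in from other sections. A hypothetical sketch of that plumbing, assembled from the comments above (the real logic in Config.check_and_update may differ in detail):

def check_and_update(config):
    # Hypothetical sketch; field paths are taken from the comments above.
    taskset = config.buffer.explorer_input.taskset

    # GenerationConfig.n mirrors algorithm.repeat_times
    taskset.rollout_args.n = config.algorithm.repeat_times

    # StorageConfig fields propagated from other sections
    taskset.algorithm_type = config.algorithm.algorithm_type
    taskset.total_epochs = config.buffer.total_epochs

    # The rollout model inherits the trained model's checkpoint path,
    # and token limits fall back to the model-level settings.
    rollout = config.explorer.rollout_model
    rollout.model_path = config.model.model_path
    if rollout.max_prompt_tokens is None:
        rollout.max_prompt_tokens = config.model.max_prompt_tokens
    if rollout.max_response_tokens is None:
        rollout.max_response_tokens = config.model.max_response_tokens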

trinity/trainer/verl_trainer.py

Lines changed: 0 additions & 1 deletion

@@ -73,7 +73,6 @@ def __init__(
         global_config: Config,
     ):
         train_config = global_config.trainer
-        pprint(train_config.trainer_config)
         config = OmegaConf.structured(train_config.trainer_config)
         # download the checkpoint from hdfs
         local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
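
With the full config now pretty-printed once in launcher.run after check_and_update(), this trainer-local debug print is presumably redundant.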
