Skip to content

Commit 9eb25fa

Browse files
committed
Update config_manager.py
1 parent 5829917 commit 9eb25fa

File tree

4 files changed

+74
-27
lines changed

4 files changed

+74
-27
lines changed

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class ExplorerConfig:
172172
backend: str = "nccl"
173173
use_ray: bool = False
174174
gpu_memory_utilization: float = 0.9
175-
enable_chunked_prefil: bool = False
175+
enable_chunked_prefill: bool = False
176176
use_v1: bool = True
177177
bundle_indices: str = "" # DO NOT SET this field
178178

trinity/common/models/vllm_async_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
task="generate",
8181
disable_log_requests=True,
8282
gpu_memory_utilization=config.explorer.gpu_memory_utilization,
83-
enable_chunked_prefill=config.explorer.enable_chunked_prefil,
83+
enable_chunked_prefill=config.explorer.enable_chunked_prefill,
8484
# max_num_batched_tokens=256, # optionally set this parameter to lower vLLM's peak memory usage
8585
)
8686
self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args)

trinity/common/models/vllm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, config: Config, **kwargs):
6565
dtype=config.explorer.dtype,
6666
trust_remote_code=True,
6767
gpu_memory_utilization=config.explorer.gpu_memory_utilization,
68-
enable_chunked_prefill=config.explorer.enable_chunked_prefil,
68+
enable_chunked_prefill=config.explorer.enable_chunked_prefill,
6969
# max_num_batched_tokens=256,
7070
**kwargs,
7171
)

trinity/manager/config_manager.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _init_default_config(self):
7272
"_not_dpo_storage_type": StorageType.QUEUE.value,
7373
"storage_type": StorageType.QUEUE.value,
7474
"train_dataset_path": "",
75-
"max_retry_times": 3,
75+
"buffer_max_retry_times": 3,
7676
"max_retry_interval": 1,
7777
"dpo_dataset_train_split": "train",
7878
"dpo_dataset_prompt_type": PromptType.MESSAGES.value,
@@ -88,31 +88,37 @@ def _init_default_config(self):
8888
# Explorer Configs
8989
"engine_type": "vllm_async",
9090
"engine_num": 2,
91-
"tensor_parallel_size": 1,
91+
"runner_num": 32,
9292
"_grouped_adv_repeat_times": 2,
9393
"_not_grouped_adv_repeat_times": 1,
9494
"repeat_times": 1,
95-
"_not_dpo_sync_method": SyncMethod.NCCL.value,
96-
"sync_method": SyncMethod.NCCL.value,
97-
"sync_interval": 10,
98-
"sync_timeout": 1200,
99-
"runner_num": 32,
100-
"max_pending_requests": 32,
101-
"max_waiting_steps": 4,
95+
"eval_interval": 1000,
96+
"tensor_parallel_size": 1,
97+
"enable_prefix_caching": False,
98+
"enforce_eager": True,
10299
"dtype": "bfloat16",
103-
"backend": "nccl",
104100
"temperature": 1.0,
105101
"top_p": 1.0,
106102
"top_k": -1,
107103
"seed": 42,
108104
"logprobs": 0,
109-
"enable_prefix_caching": False,
110-
"enforce_eager": True,
105+
"backend": "nccl",
106+
"use_ray": False,
107+
"gpu_memory_utilization": 0.9,
108+
"enable_chunked_prefill": False,
109+
"max_pending_requests": 32,
110+
"max_waiting_steps": 4,
111+
"max_timeout": 900,
112+
"explorer_max_retry_times": 2,
113+
# Synchronizer Configs
114+
"_not_dpo_sync_method": SyncMethod.NCCL.value,
115+
"sync_method": SyncMethod.NCCL.value,
116+
"sync_interval": 10,
117+
"sync_timeout": 1200,
111118
# Trainer Configs
112119
"trainer_type": "verl",
113120
"algorithm_type": AlgorithmType.PPO.value,
114121
"sft_warmup_steps": 0,
115-
"eval_interval": 1000,
116122
"_nccl_save_interval": 100,
117123
"save_interval": 100,
118124
# veRL Trainer Configs
@@ -370,8 +376,8 @@ def _set_train_dataset_path(self): # TODO
370376
self.unfinished_fields.add("train_dataset_path")
371377
st.warning("Please input train dataset path.")
372378

373-
def _set_max_retry_times(self):
374-
st.number_input("Max Retry Times", key="max_retry_times", min_value=1)
379+
def _set_buffer_max_retry_times(self):
380+
st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1)
375381

376382
def _set_max_retry_interval(self):
377383
st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1)
@@ -613,11 +619,28 @@ def _set_logprobs(self):
613619
st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
614620

615621
def _set_enable_prefix_caching(self):
616-
st.checkbox("Enable Prefix Caching", key="enable_prefix_caching")
622+
st.checkbox("Prefix Caching", key="enable_prefix_caching")
617623

618624
def _set_enforce_eager(self):
619625
st.checkbox("Enforce Eager", key="enforce_eager")
620626

627+
def _set_use_ray(self):
628+
st.checkbox("Use Ray", key="use_ray")
629+
630+
def _set_gpu_memory_utilization(self):
631+
st.number_input(
632+
"GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0
633+
)
634+
635+
def _set_enable_chunked_prefill(self):
636+
st.checkbox("Chunked Prefill", key="enable_chunked_prefill")
637+
638+
def _set_max_timeout(self):
639+
st.number_input("Max Timeout", key="max_timeout", min_value=0)
640+
641+
def _set_explorer_max_retry_times(self):
642+
st.number_input("Explorer Max Retry Times", key="explorer_max_retry_times", min_value=0)
643+
621644
def _set_trainer_type(self):
622645
st.selectbox("Trainer Type", ["verl"], key="trainer_type")
623646

@@ -1079,7 +1102,7 @@ def _expert_buffer_part(self):
10791102

10801103
self.buffer_advanced_tab = st.expander("Advanced Config")
10811104
with self.buffer_advanced_tab:
1082-
self._set_configs_with_st_columns(["max_retry_times", "max_retry_interval"])
1105+
self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"])
10831106

10841107
self._set_sft_warmup_dataset_path()
10851108
self._set_sft_warmup_dataset_args()
@@ -1094,12 +1117,22 @@ def _expert_explorer_part(self):
10941117

10951118
with st.expander("Advanced Config"):
10961119
self._set_configs_with_st_columns(
1097-
["runner_num", "max_pending_requests", "max_waiting_steps", "dtype"]
1120+
["runner_num", "temperature", "top_p", "top_k", "seed", "logprobs"]
10981121
)
10991122

1100-
self._set_configs_with_st_columns(["backend", "temperature", "seed", "logprobs"])
1123+
self._set_configs_with_st_columns(["dtype", "backend", "gpu_memory_utilization"])
1124+
self._set_configs_with_st_columns(
1125+
[
1126+
"max_pending_requests",
1127+
"max_waiting_steps",
1128+
"max_timeout",
1129+
"explorer_max_retry_times",
1130+
]
1131+
)
11011132

1102-
self._set_configs_with_st_columns(["enable_prefix_caching", "enforce_eager"])
1133+
self._set_configs_with_st_columns(
1134+
["enable_prefix_caching", "enforce_eager", "use_ray", "enable_chunked_prefill"]
1135+
)
11031136

11041137
def _expert_trainer_part(self):
11051138
self._set_configs_with_st_columns( # TODO: may add `trainer_type`
@@ -1442,6 +1475,12 @@ def generate_config(self):
14421475
else:
14431476
trainer_n_gpus_per_node = st.session_state["gpu_per_node"]
14441477

1478+
critic_model_path = (
1479+
st.session_state["critic_model_path"].strip()
1480+
if st.session_state["critic_model_path"].strip()
1481+
else st.session_state["model_path"]
1482+
)
1483+
14451484
if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
14461485
train_dataset_path = (
14471486
st.session_state["train_dataset_path"].strip()
@@ -1495,6 +1534,7 @@ def generate_config(self):
14951534
},
14961535
"model": {
14971536
"model_path": st.session_state["model_path"],
1537+
"critic_model_path": critic_model_path,
14981538
"max_prompt_tokens": st.session_state["max_prompt_tokens"],
14991539
"max_response_tokens": st.session_state["max_response_tokens"],
15001540
"checkpoint_path": st.session_state["checkpoint_path"],
@@ -1504,18 +1544,16 @@ def generate_config(self):
15041544
"gpu_per_node": st.session_state["gpu_per_node"],
15051545
},
15061546
"buffer": {
1507-
"max_retry_times": st.session_state["max_retry_times"],
1547+
"max_retry_times": st.session_state["buffer_max_retry_times"],
15081548
"max_retry_interval": st.session_state["max_retry_interval"],
15091549
"train_dataset": {
15101550
"name": "experience_buffer", # TODO
15111551
"storage_type": st.session_state["storage_type"],
1512-
"algorithm_type": st.session_state["algorithm_type"],
15131552
"path": train_dataset_path,
15141553
},
15151554
"sft_warmup_dataset": {
15161555
"name": "sft_warmup_dataset",
15171556
"storage_type": sft_storage_type,
1518-
"algorithm_type": AlgorithmType.SFT.value,
15191557
"path": st.session_state["sft_warmup_dataset_path"],
15201558
"kwargs": {
15211559
"train_split": st.session_state["sft_warmup_train_split"],
@@ -1530,18 +1568,27 @@ def generate_config(self):
15301568
"engine_type": st.session_state["engine_type"],
15311569
"engine_num": st.session_state["engine_num"],
15321570
"runner_num": st.session_state["runner_num"],
1571+
"repeat_times": st.session_state["repeat_times"],
1572+
# "chat_template": None, # TODO: add chat template
15331573
"eval_interval": st.session_state["eval_interval"],
15341574
"tensor_parallel_size": st.session_state["tensor_parallel_size"],
15351575
"enable_prefix_caching": st.session_state["enable_prefix_caching"],
15361576
"enforce_eager": st.session_state["enforce_eager"],
15371577
"dtype": st.session_state["dtype"],
15381578
"temperature": st.session_state["temperature"],
1579+
"top_p": st.session_state["top_p"], # TODO
1580+
"top_k": st.session_state["top_k"], # TODO
15391581
"seed": st.session_state["seed"],
15401582
"logprobs": st.session_state["logprobs"],
1541-
"repeat_times": st.session_state["repeat_times"],
15421583
"backend": st.session_state["backend"],
1584+
"use_ray": st.session_state["use_ray"], # TODO
1585+
"gpu_memory_utilization": st.session_state["gpu_memory_utilization"], # TODO
1586+
"enable_chunked_prefill": st.session_state["enable_chunked_prefill"], # TODO
1587+
"use_v1": True,
15431588
"max_pending_requests": st.session_state["max_pending_requests"],
15441589
"max_waiting_steps": st.session_state["max_waiting_steps"],
1590+
"max_timeout": st.session_state["max_timeout"], # TODO
1591+
"max_retry_times": st.session_state["explorer_max_retry_times"], # TODO
15451592
},
15461593
"synchronizer": {
15471594
"sync_method": st.session_state["sync_method"],

0 commit comments

Comments
 (0)