
Commit d5db95a

Implement serial saving (agentscope-ai#322)
1 parent e29bd29 commit d5db95a

10 files changed: 88 additions, 35 deletions


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 7 additions & 1 deletion

@@ -160,7 +160,7 @@ model:
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
-- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will be inferred from the model configuration.
+- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
@@ -405,6 +405,7 @@ trainer:
   trainer_type: 'verl'
   save_interval: 100
   total_steps: 1000
+  save_strategy: "unrestricted"
   trainer_config: null
   trainer_config_path: ''
 ```
@@ -413,6 +414,11 @@ trainer:
 - `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
 - `save_interval`: Frequency (in steps) at which to save model checkpoints.
 - `total_steps`: Total number of training steps.
+- `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows:
+  - `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
+  - `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
+  - `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
+  - `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
 - `trainer_config`: The trainer configuration provided inline.
 - `trainer_config_path`: The path to the trainer configuration file. Only one of `trainer_config_path` and `trainer_config` should be specified.
 
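For orientation, here is a minimal sketch of how the new field might be set programmatically. It assumes `TrainerConfig` can be constructed with defaults for its other fields, which this commit does not show; the `SaveStrategy` values themselves come from `trinity/common/constants.py` further down.

```python
# Minimal sketch (not part of the commit): selecting a save strategy in code.
# Assumes TrainerConfig's other fields all have usable defaults.
from trinity.common.config import TrainerConfig
from trinity.common.constants import SaveStrategy

# Allow parallel saving inside one node, but serialize saving across nodes.
cfg = TrainerConfig(save_strategy=SaveStrategy.SINGLE_NODE)

# Equivalent YAML, matching the example above:
#   trainer:
#     save_strategy: "single_node"
```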

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 7 additions & 1 deletion

@@ -160,7 +160,7 @@ model:
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
-- `max_model_len`: Maximum number of tokens in a single sequence supported by the model.
+- `max_model_len`: Maximum number of tokens in a single sequence supported by the model. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`; this requires both values to be already set, otherwise an error will be raised.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
@@ -405,6 +405,7 @@ trainer:
   trainer_type: 'verl'
   save_interval: 100
   total_steps: 1000
+  save_strategy: "unrestricted"
   trainer_config: null
   trainer_config_path: ''
 ```
@@ -413,6 +414,11 @@ trainer:
 - `trainer_type`: Trainer backend implementation. Currently only `verl` is supported.
 - `save_interval`: Frequency (in steps) at which model checkpoints are saved.
 - `total_steps`: Total number of training steps.
+- `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are:
+  - `single_thread`: Only one thread in the entire system is allowed to save the model; different saving threads run sequentially.
+  - `single_process`: Only one process in the entire system is allowed to perform saving; multiple threads within that process may handle saving tasks in parallel, while saving across different processes runs sequentially.
+  - `single_node`: Only one compute node in the entire system is allowed to perform saving; processes and threads within that node may work in parallel, while saving across different nodes runs sequentially.
+  - `unrestricted`: No restrictions on saving; multiple nodes, processes, or threads may save the model simultaneously.
 - `trainer_config`: The trainer configuration provided inline.
 - `trainer_config_path`: The path to the trainer configuration file. Only one of `trainer_config_path` and `trainer_config` may be specified.
 

tests/cli/launcher_test.py

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@
 
 class TestLauncherMain(unittest.TestCase):
     def setUp(self):
+        if multiprocessing.get_start_method(allow_none=True) != "spawn":
+            multiprocessing.set_start_method("spawn", force=True)
         self._orig_argv = sys.argv.copy()
         self.config = get_template_config()
         self.config.checkpoint_root_dir = get_checkpoint_path()

tests/common/vllm_test.py

Lines changed: 4 additions & 11 deletions

@@ -94,12 +94,11 @@ def print_debug(*args):
         "repeat_times",
         "enable_history",
         "use_async",
-        "max_model_len",
     ),
     [
-        (2, 2, 2, True, False, None),
-        (1, 2, 1, False, True, None),
-        (2, 1, 3, True, True, None),
+        (2, 2, 2, True, False),
+        (1, 2, 1, False, True),
+        (2, 1, 3, True, True),
     ],
 )
 class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -108,7 +107,6 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
-        self.config.model.max_model_len = self.max_model_len
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
@@ -189,12 +187,7 @@ async def test_generate(
                 "content": results[0].response_text,
             }
         )
-        if self.max_model_len is not None:
-            with self.assertRaises(ValueError):
-                exp = self.model_wrapper.convert_messages_to_experience(messages)
-            return
-        else:
-            exp = self.model_wrapper.convert_messages_to_experience(messages)
+        exp = self.model_wrapper.convert_messages_to_experience(messages)
         tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
         result_dict = tokenizer.apply_chat_template(
             messages,

tests/explorer/workflow_test.py

Lines changed: 0 additions & 1 deletion

@@ -466,7 +466,6 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
-        self.config.model.max_model_len = None  # self.max_model_len
         self.config.explorer.rollout_model.engine_num = 1  # self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = 1  # self.tensor_parallel_size
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE

trinity/common/config.py

Lines changed: 4 additions & 16 deletions

@@ -16,10 +16,10 @@
     LOG_DIR_ENV_VAR,
     LOG_LEVEL_ENV_VAR,
     LOG_NODE_IP_ENV_VAR,
-    MAX_MODEL_LEN,
     PLUGIN_DIRS_ENV_VAR,
     TRAINER_NAME,
     PromptType,
+    SaveStrategy,
     StorageType,
     SyncMethod,
     SyncStyle,
@@ -471,6 +471,8 @@ class TrainerConfig:
     actor_grad_clip: Optional[float] = None
     # TODO: extract more train-related params from underlying trainer engine
 
+    save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
+
     # Only one needs to be set for `trainer_config` and `trainer_config_path`
     trainer_config: Any = field(default_factory=dict)
     trainer_config_path: str = ""
@@ -843,21 +845,7 @@ def _check_model(self) -> None:
                     f"`max_model_len` is set to {model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`."
                 )
             else:
-                from transformers import AutoConfig, AutoTokenizer
-                from transformers.tokenization_utils_base import LARGE_INTEGER
-
-                tokenizer = AutoTokenizer.from_pretrained(model.model_path)
-                config = AutoConfig.from_pretrained(model.model_path)
-                max_model_len = min(
-                    getattr(tokenizer, "model_max_length", LARGE_INTEGER),
-                    getattr(config, "max_position_embeddings", LARGE_INTEGER),
-                )
-                if max_model_len >= LARGE_INTEGER:
-                    max_model_len = MAX_MODEL_LEN
-                    logger.warning(
-                        f"Failed to get `max_model_len` from model {model.model_path}, use {MAX_MODEL_LEN} instead."
-                    )
-                model.max_model_len = max_model_len
+                raise ValueError("Unable to determine `max_model_len`, please set it manually.")
 
         # both max_prompt_tokens and max_response_tokens are None
         if model.max_prompt_tokens is None and model.max_response_tokens is None:
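The removed fallback used to infer `max_model_len` from the HuggingFace tokenizer/config; after this change the check either derives the value from the token budgets or fails fast. A condensed sketch of the resulting logic (illustrative only, not the actual `_check_model` body):

```python
# Illustrative restatement of the new max_model_len resolution; `model` stands for
# the model config object validated by _check_model.
def resolve_max_model_len(model):
    if model.max_model_len is not None:
        return model.max_model_len  # explicitly configured value wins
    if model.max_prompt_tokens is not None and model.max_response_tokens is not None:
        # Derived as documented in the tutorial change above.
        return model.max_prompt_tokens + model.max_response_tokens
    # No more automatic inference from the transformers config: fail fast instead.
    raise ValueError("Unable to determine `max_model_len`, please set it manually.")
```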

trinity/common/constants.py

Lines changed: 7 additions & 0 deletions

@@ -104,3 +104,10 @@ class SyncStyle(CaseInsensitiveEnum):
     FIXED = "fixed"
     DYNAMIC_BY_TRAINER = "dynamic_by_trainer"
     DYNAMIC_BY_EXPLORER = "dynamic_by_explorer"
+
+
+class SaveStrategy(CaseInsensitiveEnum):
+    SINGLE_THREAD = "single_thread"
+    SINGLE_PROCESS = "single_process"
+    SINGLE_NODE = "single_node"
+    UNRESTRICTED = "unrestricted"

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 12 additions & 0 deletions

@@ -115,6 +115,10 @@ def _save_with_thread(
         thread.join()
 
         def _save():
+            runtime_context = ray.get_runtime_context()
+            node_id = runtime_context.get_node_id()
+            job_id = runtime_context.get_job_id()
+            ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
             torch.save(obj, path)
             log_with_rank(
                 f"Saved {prefix} to {os.path.abspath(path)}",
@@ -357,6 +361,14 @@ def save_checkpoint( # noqa: C901
             self._save_model_thread.join()
 
             def _save_model():
+                runtime_context = ray.get_runtime_context()
+                node_id = runtime_context.get_node_id()
+                job_id = runtime_context.get_job_id()
+                ray.get(
+                    self.checkpoint_monitor.notify_started.remote(
+                        node_id=node_id, job_id=job_id
+                    )
+                )
                 save_model.save_pretrained(hf_local_path, state_dict=state_dict)
                 log_with_rank(
                     f"Saved hf_model to {os.path.abspath(hf_local_path)}",

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 4 additions & 0 deletions

@@ -125,6 +125,10 @@ def _save_state_dict(self, local_path, global_step):
 
         def finalize_save_fn():
            # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided
+            runtime_context = ray.get_runtime_context()
+            node_id = runtime_context.get_node_id()
+            job_id = runtime_context.get_job_id()
+            ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
            log_with_rank(
                f"Dist checkpointing save completed for {dist_checkpoint_path}",
                rank=self.rank,

trinity/trainer/verl_trainer.py

Lines changed: 41 additions & 5 deletions

@@ -3,6 +3,7 @@
 
 Modified from verl/trainer/ppo/ray_trainer.py
 """
+import asyncio
 import os
 import sys
 from collections import defaultdict
@@ -33,14 +34,17 @@
 from trinity.algorithm.algorithm import ALGORITHM_TYPE
 from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import Config
+from trinity.common.constants import SaveStrategy
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
 from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto
 from trinity.utils.log import get_logger
 
 
 class CheckpointMonitor:
-    def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
+    def __init__(
+        self, save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None
+    ):
         self.logger = get_logger("checkpoint_monitor", in_ray_actor=True)
         self.default_local_dir = default_local_dir
         self.default_hdfs_dir = default_hdfs_dir
@@ -57,6 +61,11 @@ def __init__(
         self.latest_checkpoint_step = 0
         self.latest_state_dict_step = 0
 
+        self.save_strategy = save_strategy
+        self.condition = asyncio.Condition()
+        self.current_identifier = 0
+        self.saving_count = 0
+
     def update_latest_checkpoint_step(self, step: int):
         assert step >= self.latest_checkpoint_step
         if step == self.latest_checkpoint_step:
@@ -87,7 +96,7 @@ def update_latest_state_dict_step(self, step: int):
         with open(self.local_latest_state_dict_iteration, "w") as f:
             f.write(str(step))
 
-    def register_thread_count(
+    async def register_thread_count(
         self,
         step: int,
         *,
@@ -99,7 +108,7 @@ def register_thread_count(
         if checkpoint_thread_count != 0:
             self.checkpoint_counter[step] += checkpoint_thread_count
 
-    def monitor_step(self, step: int, is_state_dict: bool = False):
+    async def monitor_step(self, step: int, is_state_dict: bool = False):
         if is_state_dict:
             self.state_dict_steps.add(step)
             if self.state_dict_counter[step] == 0:
@@ -109,7 +118,28 @@ def monitor_step(self, step: int, is_state_dict: bool = False):
            if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0:
                self.update_latest_checkpoint_step(step)
 
-    def notify_finished(self, step: int, is_state_dict: bool = False):
+    async def notify_started(self, node_id: str, job_id: str):
+        if self.save_strategy == SaveStrategy.SINGLE_THREAD:
+            identifier = self.current_identifier + 1
+        elif self.save_strategy == SaveStrategy.SINGLE_PROCESS:
+            identifier = f"{node_id}_{job_id}"
+        elif self.save_strategy == SaveStrategy.SINGLE_NODE:
+            identifier = node_id
+        elif self.save_strategy == SaveStrategy.UNRESTRICTED:
+            return
+        else:
+            raise ValueError(f"Invalid save strategy: {self.save_strategy}")
+
+        async with self.condition:
+            if identifier != self.current_identifier and self.saving_count > 0:
+                await self.condition.wait_for(lambda: self.saving_count == 0)
+            self.current_identifier = identifier
+            self.saving_count += 1
+
+    async def notify_finished(self, step: int, is_state_dict: bool = False):
+        async with self.condition:
+            self.saving_count -= 1
+            self.condition.notify_all()
        if is_state_dict:
            self.state_dict_counter[step] -= 1
            if (
@@ -131,6 +161,7 @@ def notify_finished(self, step: int, is_state_dict: bool = False):
     def get_actor(
         cls,
         namespace: str,
+        save_strategy: Optional[SaveStrategy] = None,
         default_local_dir: Optional[str] = None,
         default_hdfs_dir: Optional[str] = None,
     ):
@@ -141,7 +172,11 @@ def get_actor(
                 namespace=namespace,
                 get_if_exists=True,
             )
-            .remote(default_local_dir=default_local_dir, default_hdfs_dir=default_hdfs_dir)
+            .remote(
+                save_strategy=save_strategy,
+                default_local_dir=default_local_dir,
+                default_hdfs_dir=default_hdfs_dir,
+            )
         )
 
 
@@ -191,6 +226,7 @@ def __init__(
 
         self.checkpoint_monitor = CheckpointMonitor.get_actor(
             namespace=global_config.synchronizer.ray_namespace,
+            save_strategy=global_config.trainer.save_strategy,
             default_local_dir=config.trainer.default_local_dir,
             default_hdfs_dir=config.trainer.default_hdfs_dir,
         )
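Putting the pieces together: the checkpoint managers call `notify_started` (and, as before, `notify_finished`) on the `CheckpointMonitor` actor, and because the workers wrap the remote call in `ray.get(...)`, a saving thread blocks until the monitor admits it. The `asyncio.Condition` added above is what enforces the configured `save_strategy`. A condensed, standalone sketch of that gating logic, with the Ray actor machinery and step bookkeeping stripped out (illustrative, not the actual class):

```python
import asyncio


class SaveGate:
    """Simplified stand-in for CheckpointMonitor's serialization logic."""

    def __init__(self):
        self.condition = asyncio.Condition()
        self.current_identifier = 0  # a counter, f"{node_id}_{job_id}", or node_id
        self.saving_count = 0

    async def notify_started(self, identifier):
        async with self.condition:
            # A save under a *different* identifier must wait until every in-flight
            # save of the current identifier has finished.
            if identifier != self.current_identifier and self.saving_count > 0:
                await self.condition.wait_for(lambda: self.saving_count == 0)
            self.current_identifier = identifier
            self.saving_count += 1

    async def notify_finished(self):
        async with self.condition:
            self.saving_count -= 1
            self.condition.notify_all()
```

Under `single_thread` every call gets a fresh identifier, so saves run strictly one at a time; `single_process` keys the identifier on `f"{node_id}_{job_id}"` and `single_node` on `node_id`, so saves sharing that identifier can overlap while others wait; `unrestricted` bypasses the gate entirely.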
