Commit 4fac567

add save strategy
1 parent 43f162d commit 4fac567

File tree: 7 files changed, +56 −10 lines

docs/sphinx_doc/source/tutorial/trinity_configs.md
docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
trinity/common/config.py
trinity/common/constants.py
trinity/trainer/verl/fsdp_checkpoint_manager.py
trinity/trainer/verl/megatron_checkpoint_manager.py
trinity/trainer/verl_trainer.py

docs/sphinx_doc/source/tutorial/trinity_configs.md
Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ model:
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
-- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will be inferred from the model configuration.
+- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will default to `max_prompt_tokens` + `max_response_tokens`. However, if either `max_prompt_tokens` or `max_response_tokens` is not set, we will raise an error.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
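To make the new fallback behavior concrete, here is a minimal sketch of the documented `model` section (the field names come from the docs above; the path and numbers are placeholders, not values from this commit):

```yaml
model:
  model_path: /path/to/model      # model being trained
  max_prompt_tokens: 2048         # prompt budget
  max_response_tokens: 2048       # response budget
  # max_model_len may be omitted; it then defaults to
  # max_prompt_tokens + max_response_tokens = 4096.
  # Omitting it while either limit above is also unset raises an error.
  max_model_len: 4096
```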

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ model:
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
-- `max_model_len`: Maximum number of tokens in a single sequence supported by the model.
+- `max_model_len`: Maximum number of tokens in a single sequence supported by the model. If not set, it defaults to `max_prompt_tokens + max_response_tokens`; however, if either `max_prompt_tokens` or `max_response_tokens` is unset, an error is raised.
 - `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only applies to the `chat` and `generate` methods in `InferenceModel`.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only applies to the `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only applies to the `chat` and `generate` methods in `InferenceModel`.

trinity/common/config.py
Lines changed: 3 additions & 0 deletions

@@ -19,6 +19,7 @@
     PLUGIN_DIRS_ENV_VAR,
     TRAINER_NAME,
     PromptType,
+    SaveStrategy,
     StorageType,
     SyncMethod,
     SyncStyle,
@@ -470,6 +471,8 @@ class TrainerConfig:
     actor_grad_clip: Optional[float] = None
     # TODO: extract more train-related params from underlying trainer engine
 
+    save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
+
     # Only one needs to be set for `trainer_config` and `trainer_config_path`
     trainer_config: Any = field(default_factory=dict)
     trainer_config_path: str = ""
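For illustration, the new field in use (a sketch assuming the remaining `TrainerConfig` fields keep their defaults, as the ones visible in the hunk do):

```python
from trinity.common.config import TrainerConfig
from trinity.common.constants import SaveStrategy

cfg = TrainerConfig()
assert cfg.save_strategy is SaveStrategy.UNRESTRICTED  # default: saves are not gated

# Serialize saves across nodes; saves on the same node may still overlap.
cfg = TrainerConfig(save_strategy=SaveStrategy.SINGLE_NODE)
```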

trinity/common/constants.py
Lines changed: 7 additions & 0 deletions

@@ -104,3 +104,10 @@ class SyncStyle(CaseInsensitiveEnum):
     FIXED = "fixed"
     DYNAMIC_BY_TRAINER = "dynamic_by_trainer"
     DYNAMIC_BY_EXPLORER = "dynamic_by_explorer"
+
+
+class SaveStrategy(CaseInsensitiveEnum):
+    SINGLE_THREAD = "single_thread"
+    SINGLE_PROCESS = "single_process"
+    SINGLE_NODE = "single_node"
+    UNRESTRICTED = "unrestricted"
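The new enum members are keyed by lowercase string values, so configs can refer to them by name. A small sketch (how far `CaseInsensitiveEnum` normalizes mixed-case input is defined elsewhere in constants.py, not in this commit):

```python
from trinity.common.constants import SaveStrategy

strategy = SaveStrategy("single_node")
assert strategy is SaveStrategy.SINGLE_NODE
```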

trinity/trainer/verl/fsdp_checkpoint_manager.py
Lines changed: 12 additions & 2 deletions

@@ -115,7 +115,10 @@ def _save_with_thread(
             thread.join()
 
         def _save():
-            ray.get(self.checkpoint_monitor.notify_started.remote())
+            runtime_context = ray.get_runtime_context()
+            node_id = runtime_context.get_node_id()
+            job_id = runtime_context.get_job_id()
+            ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
             torch.save(obj, path)
             log_with_rank(
                 f"Saved {prefix} to {os.path.abspath(path)}",
@@ -358,7 +361,14 @@ def save_checkpoint( # noqa: C901
             self._save_model_thread.join()
 
         def _save_model():
-            ray.get(self.checkpoint_monitor.notify_started.remote())
+            runtime_context = ray.get_runtime_context()
+            node_id = runtime_context.get_node_id()
+            job_id = runtime_context.get_job_id()
+            ray.get(
+                self.checkpoint_monitor.notify_started.remote(
+                    node_id=node_id, job_id=job_id
+                )
+            )
             save_model.save_pretrained(hf_local_path, state_dict=state_dict)
             log_with_rank(
                 f"Saved hf_model to {os.path.abspath(hf_local_path)}",
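Both `_save` paths now look up where they are running before notifying the monitor. For reference, a standalone sketch of the Ray runtime-context calls being used (plain Ray API usage, not code from this commit):

```python
import ray

ray.init()  # or connect to an existing cluster

ctx = ray.get_runtime_context()
node_id = ctx.get_node_id()  # hex ID of the node this worker or driver runs on
job_id = ctx.get_job_id()    # hex ID of the current Ray job
print(node_id, job_id)
```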

trinity/trainer/verl/megatron_checkpoint_manager.py
Lines changed: 4 additions & 1 deletion

@@ -125,7 +125,10 @@ def _save_state_dict(self, local_path, global_step):
 
         def finalize_save_fn():
             # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided
-            ray.get(self.checkpoint_monitor.notify_started.remote())
+            runtime_context = ray.get_runtime_context()
+            node_id = runtime_context.get_node_id()
+            job_id = runtime_context.get_job_id()
+            ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
             log_with_rank(
                 f"Dist checkpointing save completed for {dist_checkpoint_path}",
                 rank=self.rank,

trinity/trainer/verl_trainer.py
Lines changed: 28 additions & 5 deletions

@@ -34,14 +34,17 @@
 from trinity.algorithm.algorithm import ALGORITHM_TYPE
 from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import Config
+from trinity.common.constants import SaveStrategy
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
 from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto
 from trinity.utils.log import get_logger
 
 
 class CheckpointMonitor:
-    def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
+    def __init__(
+        self, save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None
+    ):
         self.logger = get_logger("checkpoint_monitor", in_ray_actor=True)
         self.default_local_dir = default_local_dir
         self.default_hdfs_dir = default_hdfs_dir
@@ -58,7 +61,9 @@ def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
         self.latest_checkpoint_step = 0
         self.latest_state_dict_step = 0
 
+        self.save_strategy = save_strategy
         self.condition = asyncio.Condition()
+        self.current_identifier = 0
         self.saving_count = 0
 
     def update_latest_checkpoint_step(self, step: int):
@@ -113,16 +118,28 @@ async def monitor_step(self, step: int, is_state_dict: bool = False):
         if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0:
             self.update_latest_checkpoint_step(step)
 
-    async def notify_started(self):
+    async def notify_started(self, node_id: str, job_id: str):
+        if self.save_strategy == SaveStrategy.SINGLE_THREAD:
+            identifier = self.current_identifier + 1
+        elif self.save_strategy == SaveStrategy.SINGLE_PROCESS:
+            identifier = f"{node_id}_{job_id}"
+        elif self.save_strategy == SaveStrategy.SINGLE_NODE:
+            identifier = node_id
+        elif self.save_strategy == SaveStrategy.UNRESTRICTED:
+            return
+        else:
+            raise ValueError(f"Invalid save strategy: {self.save_strategy}")
+
         async with self.condition:
-            while self.saving_count > 0:
+            if identifier != self.current_identifier and self.saving_count > 0:
                 await self.condition.wait_for(lambda: self.saving_count == 0)
+            self.current_identifier = identifier
             self.saving_count += 1
 
     async def notify_finished(self, step: int, is_state_dict: bool = False):
         async with self.condition:
             self.saving_count -= 1
-            self.condition.notify()
+            self.condition.notify_all()
             if is_state_dict:
                 self.state_dict_counter[step] -= 1
                 if (
@@ -144,6 +161,7 @@ async def notify_finished(self, step: int, is_state_dict: bool = False):
     def get_actor(
         cls,
         namespace: str,
+        save_strategy: Optional[SaveStrategy] = None,
         default_local_dir: Optional[str] = None,
         default_hdfs_dir: Optional[str] = None,
     ):
@@ -154,7 +172,11 @@ def get_actor(
                 namespace=namespace,
                 get_if_exists=True,
             )
-            .remote(default_local_dir=default_local_dir, default_hdfs_dir=default_hdfs_dir)
+            .remote(
+                save_strategy=save_strategy,
+                default_local_dir=default_local_dir,
+                default_hdfs_dir=default_hdfs_dir,
+            )
         )
 
 
@@ -204,6 +226,7 @@ def __init__(
 
         self.checkpoint_monitor = CheckpointMonitor.get_actor(
             namespace=global_config.synchronizer.ray_namespace,
+            save_strategy=global_config.trainer.save_strategy,
             default_local_dir=config.trainer.default_local_dir,
             default_hdfs_dir=config.trainer.default_hdfs_dir,
         )
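Putting the pieces together: `notify_started` now maps each save request to an identifier (a fresh counter value for `SINGLE_THREAD`, `node_id_job_id` for `SINGLE_PROCESS`, `node_id` for `SINGLE_NODE`, and no gating at all for `UNRESTRICTED`) and only blocks a request whose identifier differs from the one currently saving. A minimal standalone sketch of that gating pattern (method and attribute names follow the diff; everything else is illustrative, not the project's code):

```python
import asyncio


class SaveGate:
    """Simplified stand-in for the gating logic in CheckpointMonitor."""

    def __init__(self):
        self.condition = asyncio.Condition()
        self.current_identifier = 0
        self.saving_count = 0

    async def notify_started(self, identifier) -> None:
        async with self.condition:
            # Requests carrying the currently active identifier may overlap;
            # a different identifier waits for in-flight saves to drain first.
            if identifier != self.current_identifier and self.saving_count > 0:
                await self.condition.wait_for(lambda: self.saving_count == 0)
            self.current_identifier = identifier
            self.saving_count += 1

    async def notify_finished(self) -> None:
        async with self.condition:
            self.saving_count -= 1
            self.condition.notify_all()
```

Under `SINGLE_NODE`, for example, saves issued from the same node share an identifier and run concurrently, while a save from another node waits until the in-flight ones have called `notify_finished`; `SINGLE_THREAD` gives every call a new identifier, so saves are fully serialized.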
