Commit 552fe2d

1. Implement serial save.
2. No longer set `max_model_len` from the model's `config.json`.
1 parent 894858b

File tree: 4 files changed (+20, −18 lines)

trinity/common/config.py

Lines changed: 1 addition & 15 deletions
```diff
@@ -843,21 +843,7 @@ def _check_model(self) -> None:
                 f"`max_model_len` is set to {model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`."
             )
         else:
-            from transformers import AutoConfig, AutoTokenizer
-            from transformers.tokenization_utils_base import LARGE_INTEGER
-
-            tokenizer = AutoTokenizer.from_pretrained(model.model_path)
-            config = AutoConfig.from_pretrained(model.model_path)
-            max_model_len = min(
-                getattr(tokenizer, "model_max_length", LARGE_INTEGER),
-                getattr(config, "max_position_embeddings", LARGE_INTEGER),
-            )
-            if max_model_len >= LARGE_INTEGER:
-                max_model_len = MAX_MODEL_LEN
-                logger.warning(
-                    f"Failed to get `max_model_len` from model {model.model_path}, use {MAX_MODEL_LEN} instead."
-                )
-            model.max_model_len = max_model_len
+            raise ValueError("Unable to determine `max_model_len`, please set it manually.")
 
         # both max_prompt_tokens and max_response_tokens are None
         if model.max_prompt_tokens is None and model.max_response_tokens is None:
```
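With the fallback gone, a config that leaves `max_model_len` (and the token budgets) unset now fails fast with a `ValueError` instead of silently inferring a length from the checkpoint. For reference, the removed inference is equivalent to this standalone helper, a sketch assembled from the deleted lines; `fallback` stands in for the project's `MAX_MODEL_LEN` constant, whose value is not shown in this diff:

```python
from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER


def infer_max_model_len(model_path: str, fallback: int = 4096) -> int:
    """Reproduce the removed fallback: take the tighter of the tokenizer's
    and the model config's length limits, or `fallback` if neither is set."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    max_model_len = min(
        getattr(tokenizer, "model_max_length", LARGE_INTEGER),
        getattr(config, "max_position_embeddings", LARGE_INTEGER),
    )
    # Tokenizers report LARGE_INTEGER (a huge sentinel) when no limit is set,
    # so a value at or above it means "no usable limit was found".
    return fallback if max_model_len >= LARGE_INTEGER else max_model_len
```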

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -115,6 +115,7 @@ def _save_with_thread(
             thread.join()
 
         def _save():
+            ray.get(self.checkpoint_monitor.notify_started.remote())
             torch.save(obj, path)
             log_with_rank(
                 f"Saved {prefix} to {os.path.abspath(path)}",
@@ -357,6 +358,7 @@ def save_checkpoint( # noqa: C901
             self._save_model_thread.join()
 
         def _save_model():
+            ray.get(self.checkpoint_monitor.notify_started.remote())
             save_model.save_pretrained(hf_local_path, state_dict=state_dict)
             log_with_rank(
                 f"Saved hf_model to {os.path.abspath(hf_local_path)}",
```

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -125,6 +125,7 @@ def _save_state_dict(self, local_path, global_step):
 
         def finalize_save_fn():
             # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided
+            ray.get(self.checkpoint_monitor.notify_started.remote())
             log_with_rank(
                 f"Dist checkpointing save completed for {dist_checkpoint_path}",
                 rank=self.rank,
```
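Taken together, the three insertions (the FSDP tensor save, the HF `save_pretrained` path, and this Megatron finalize callback) make every checkpoint writer block on `notify_started` before touching disk, so at most one write is in flight at a time; this appears to be the "serial save" referenced in the commit message.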

trinity/trainer/verl_trainer.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -3,6 +3,7 @@
 
 Modified from verl/trainer/ppo/ray_trainer.py
 """
+import asyncio
 import os
 import sys
 from collections import defaultdict
@@ -57,6 +58,9 @@ def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
         self.latest_checkpoint_step = 0
         self.latest_state_dict_step = 0
 
+        self.condition = asyncio.Condition()
+        self.saving_count = 0
+
     def update_latest_checkpoint_step(self, step: int):
         assert step >= self.latest_checkpoint_step
         if step == self.latest_checkpoint_step:
@@ -87,7 +91,7 @@ def update_latest_state_dict_step(self, step: int):
         with open(self.local_latest_state_dict_iteration, "w") as f:
             f.write(str(step))
 
-    def register_thread_count(
+    async def register_thread_count(
         self,
         step: int,
         *,
@@ -99,7 +103,7 @@ def register_thread_count(
         if checkpoint_thread_count != 0:
             self.checkpoint_counter[step] += checkpoint_thread_count
 
-    def monitor_step(self, step: int, is_state_dict: bool = False):
+    async def monitor_step(self, step: int, is_state_dict: bool = False):
         if is_state_dict:
             self.state_dict_steps.add(step)
             if self.state_dict_counter[step] == 0:
@@ -109,7 +113,16 @@ def monitor_step(self, step: int, is_state_dict: bool = False):
         if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0:
             self.update_latest_checkpoint_step(step)
 
-    def notify_finished(self, step: int, is_state_dict: bool = False):
+    async def notify_started(self):
+        async with self.condition:
+            while self.saving_count > 0:
+                await self.condition.wait_for(lambda: self.saving_count == 0)
+            self.saving_count += 1
+
+    async def notify_finished(self, step: int, is_state_dict: bool = False):
+        async with self.condition:
+            self.saving_count -= 1
+            self.condition.notify()
         if is_state_dict:
             self.state_dict_counter[step] -= 1
             if (
```
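The new `notify_started`/`notify_finished` pair amounts to a one-slot gate built on `asyncio.Condition`: `saving_count` admits a saver only when it is zero. A self-contained sketch of the same pattern (the `SaveGate` class and `demo` names are illustrative, not from the codebase):

```python
import asyncio


class SaveGate:
    """Sketch of the monitor's gate: an asyncio.Condition plus a counter
    that admits at most one saver at a time."""

    def __init__(self):
        self.condition = asyncio.Condition()
        self.saving_count = 0

    async def notify_started(self):
        async with self.condition:
            # Wait until no save is in flight, then claim the single slot.
            await self.condition.wait_for(lambda: self.saving_count == 0)
            self.saving_count += 1

    async def notify_finished(self):
        async with self.condition:
            self.saving_count -= 1
            self.condition.notify()  # wake one waiting saver


async def demo():
    gate = SaveGate()

    async def save(name: str):
        await gate.notify_started()
        print(name, "saving")
        await asyncio.sleep(0.1)  # stand-in for torch.save / save_pretrained
        await gate.notify_finished()
        print(name, "done")

    # The second save cannot start until the first one finishes.
    await asyncio.gather(save("fsdp"), save("hf_model"))


asyncio.run(demo())
```

Because Ray runs an async actor's methods as coroutines on a single event loop, concurrent `notify_started` calls queue up inside the actor rather than blocking it, and each checkpoint manager's `ray.get(...)` blocks only its own saver thread.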
