Commit a2708fd

bug fix in checkpoint monitor
1 parent cdfd70e commit a2708fd

4 files changed: +237 -204 lines changed


trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 45 additions & 56 deletions
@@ -98,79 +98,62 @@ def upload_state_dict(self, global_step: int):
         state_dict = self.model.state_dict()
         self._upload_state_dict(state_dict, global_step)
 
-    def _save_model(self, local_path, global_step):
-        model_path = os.path.join(
-            local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt"
+    def _save_with_thread(
+        self,
+        obj,
+        local_path: str,
+        prefix: str,
+        thread_name: str,
+        global_step: int,
+        is_state_dict: bool = False,
+    ):
+        path = os.path.join(
+            local_path, f"{prefix}_world_size_{self.world_size}_rank_{self.rank}.pt"
         )
-        model_state_dict = self.model.state_dict()
-        if self._model_state_dict_thread is not None:
-            self._model_state_dict_thread.join()
+        thread = getattr(self, thread_name)
+        if thread is not None:
+            thread.join()
 
-        def _save_model_state_dict():
-            torch.save(model_state_dict, model_path)
+        def _save():
+            torch.save(obj, path)
             log_with_rank(
-                f"Saved model to {os.path.abspath(model_path)}",
+                f"Saved {prefix} to {os.path.abspath(path)}",
                 rank=self.rank,
                 logger=logger,
             )
-            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step, True))
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step, is_state_dict))
+
+        thread = threading.Thread(
+            target=_save,
+        )
+        thread.start()
+        setattr(self, thread_name, thread)
 
-        self._model_state_dict_thread = threading.Thread(
-            target=_save_model_state_dict,
+    def _save_model(self, local_path, global_step):
+        model_state_dict = self.model.state_dict()
+        self._save_with_thread(
+            model_state_dict, local_path, "model", "_model_state_dict_thread", global_step, True
         )
-        self._model_state_dict_thread.start()
 
         self.previous_state_dict_step = global_step
 
     def _save_optimizer(self, local_path, global_step):
-        optim_path = os.path.join(
-            local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt"
-        )
         optimizer_state_dict = self.optimizer.state_dict()
-        if self._optimizer_state_dict_thread is not None:
-            self._optimizer_state_dict_thread.join()
-
-        def _save_optimizer_state_dict():
-            torch.save(optimizer_state_dict, optim_path)
-            log_with_rank(
-                f"Saved optim to {os.path.abspath(optim_path)}",
-                rank=self.rank,
-                logger=logger,
-            )
-            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))
-
-        self._optimizer_state_dict_thread = threading.Thread(
-            target=_save_optimizer_state_dict,
+        self._save_with_thread(
+            optimizer_state_dict, local_path, "optim", "_optimizer_state_dict_thread", global_step
         )
-        self._optimizer_state_dict_thread.start()
 
     def _save_extra_state(self, local_path, global_step):
-        extra_path = os.path.join(
-            local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt"
-        )
         lr_scheduler_state_dict = (
             self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
         )
         extra_state_dict = {
             "lr_scheduler": lr_scheduler_state_dict,
             "rng": self.get_rng_state(),
         }
-        if self._extra_state_dict_thread is not None:
-            self._extra_state_dict_thread.join()
-
-        def _save_extra_state_dict():
-            torch.save(extra_state_dict, extra_path)
-            log_with_rank(
-                f"Saved extra_state to {os.path.abspath(extra_path)}",
-                rank=self.rank,
-                logger=logger,
-            )
-            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))
-
-        self._extra_state_dict_thread = threading.Thread(
-            target=_save_extra_state_dict,
+        self._save_with_thread(
+            extra_state_dict, local_path, "extra_state", "_extra_state_dict_thread", global_step
         )
-        self._extra_state_dict_thread.start()
 
     def save_state_dict(  # noqa: C901
         self,
@@ -200,7 +183,11 @@ def save_state_dict(  # noqa: C901
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             self._save_model(local_path, global_step)
-        ray.get(self.checkpoint_monitor.register_state_dict_save_count.remote(global_step, 1))
+        ray.get(
+            self.checkpoint_monitor.register_thread_count.remote(
+                global_step, state_dict_thread_count=1
+            )
+        )
 
     def save_checkpoint(  # noqa: C901
         self,
@@ -259,7 +246,7 @@ def save_checkpoint(  # noqa: C901
         ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']"
 
         state_dict_thread_count = 0
-        other_thread_count = 0
+        checkpoint_thread_count = 0
 
         # every rank will save its own model and optim shard
         state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
@@ -275,11 +262,11 @@ def save_checkpoint(  # noqa: C901
             self._save_model(local_path, global_step)
 
             if self.should_save_optimizer:
-                other_thread_count += 1
+                checkpoint_thread_count += 1
                 self._save_optimizer(local_path, global_step)
 
             if self.should_save_extra:
-                other_thread_count += 1
+                checkpoint_thread_count += 1
                 self._save_extra_state(local_path, global_step)
 
             if self.rank == 0:
@@ -333,7 +320,7 @@ def save_checkpoint(  # noqa: C901
             state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)
 
             if self.rank == 0:
-                other_thread_count += 1
+                checkpoint_thread_count += 1
                 hf_local_path = os.path.join(local_path, "huggingface")
                 os.makedirs(hf_local_path, exist_ok=True)
 
@@ -390,8 +377,10 @@ def _save_model():
         torch.distributed.barrier()
 
         ray.get(
-            self.checkpoint_monitor.register_checkpoint_save_count.remote(
-                global_step, state_dict_thread_count, other_thread_count
+            self.checkpoint_monitor.register_thread_count.remote(
+                global_step,
+                state_dict_thread_count=state_dict_thread_count,
+                checkpoint_thread_count=checkpoint_thread_count,
            )
         )
         self.previous_saved_paths.append(local_path)
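
For context, the diff only shows the caller side of the monitor protocol: each rank registers how many background save threads it will spawn via register_thread_count, and every save thread reports back through notify_finished. Below is a minimal sketch of a Ray actor that would satisfy those two calls; the class name, internal counters, and the is_step_complete query are illustrative assumptions, not the project's actual CheckpointMonitor.

# Minimal sketch, assuming a Ray actor behind `self.checkpoint_monitor`.
# Only `register_thread_count` and `notify_finished` (and their arguments)
# come from the calls visible in this diff; everything else is hypothetical.
from collections import defaultdict

import ray


@ray.remote
class CheckpointMonitorSketch:
    def __init__(self):
        # expected / finished save-thread counts per training step,
        # split into state-dict saves and full-checkpoint saves
        self._expected = defaultdict(lambda: {"state_dict": 0, "checkpoint": 0})
        self._finished = defaultdict(lambda: {"state_dict": 0, "checkpoint": 0})

    def register_thread_count(
        self,
        global_step: int,
        state_dict_thread_count: int = 0,
        checkpoint_thread_count: int = 0,
    ):
        # each rank registers how many background save threads it will spawn
        self._expected[global_step]["state_dict"] += state_dict_thread_count
        self._expected[global_step]["checkpoint"] += checkpoint_thread_count

    def notify_finished(self, global_step: int, is_state_dict: bool = False):
        # each save thread reports completion after torch.save returns
        key = "state_dict" if is_state_dict else "checkpoint"
        self._finished[global_step][key] += 1

    def is_step_complete(self, global_step: int) -> bool:
        # hypothetical query: True once every registered thread has reported
        expected = self._expected[global_step]
        finished = self._finished[global_step]
        return all(finished[k] >= expected[k] for k in expected)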
