Skip to content

Commit b4224d3

Browse files
committed
apply suggestions from gemini
1 parent 7d8e99a commit b4224d3

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

tests/trainer/trainer_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -709,13 +709,13 @@ def test_trainer(self):
709709
try:
710710
with open(state_dict_iteration_file, "r") as f:
711711
state_dict_iteration = int(f.read().strip())
712-
except Exception:
712+
except (IOError, ValueError):
713713
pass
714714
if os.path.exists(checkpoint_iteration_file):
715715
try:
716716
with open(checkpoint_iteration_file, "r") as f:
717717
checkpoint_iteration = int(f.read().strip())
718-
except Exception:
718+
except (IOError, ValueError):
719719
pass
720720

721721
if state_dict_iteration > 0:
@@ -777,7 +777,7 @@ def test_trainer(self):
777777
"special_tokens_map.json",
778778
},
779779
)
780-
print(f"Checkpoint check at {state_dict_iteration} iteration passed.")
780+
print(f"Checkpoint check at {checkpoint_iteration} iteration passed.")
781781

782782
time.sleep(1)
783783
trainer_process.join()

trinity/manager/synchronizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ async def _find_latest_state_dict(self) -> None:
8585
try:
8686
with open(local_latest_state_dict_iteration, "r") as f:
8787
latest_model_version = int(f.read().strip())
88-
except Exception:
88+
except (IOError, ValueError) as e:
89+
self.logger.warning(f"Failed to read or parse state dict iteration file: {e}")
8990
continue
9091
if latest_model_version > self.model_version:
9192
self.logger.info(
@@ -98,7 +99,7 @@ async def _find_latest_state_dict(self) -> None:
9899
self.config.trainer,
99100
)
100101
self.logger.info(
101-
f"Synchronizer has loaded model state dict from checkpoint {self.model_version}."
102+
f"Synchronizer has loaded model state dict from checkpoint {latest_model_version}."
102103
)
103104
await self.set_model_state_dict(model_state_dict, latest_model_version)
104105
await asyncio.sleep(1)

0 commit comments

Comments
 (0)