Skip to content

Commit 9218948

Browse files
Fix/delete old ckpts (#174)
* fix: distributed training
* fix: typing
* release: version 0.5.1
1 parent 0453284 commit 9218948

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ warn_return_any = false
196196

197197

198198
[tool.bumpver]
199-
current_version = "0.5.0"
199+
current_version = "0.5.1"
200200
version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
201201
commit_message = "bump version {old_version} -> {new_version}"
202202
tag_message = "v{new_version}"

sheeprl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@
5050
np.int = np.int64
5151
np.bool = bool
5252

53-
__version__ = "0.5.0"
53+
__version__ = "0.5.1"
5454

5555

5656
# Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10

sheeprl/utils/callback.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def on_checkpoint_coupled(
5252
fabric.save(ckpt_path, state)
5353
if replay_buffer is not None:
5454
self._experiment_consistent_rb(replay_buffer, rb_state)
55-
if self.keep_last:
55+
if fabric.is_global_zero and self.keep_last:
5656
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)
5757

5858
def on_checkpoint_player(
@@ -71,7 +71,7 @@ def on_checkpoint_player(
7171
fabric.save(ckpt_path, state)
7272
if replay_buffer is not None:
7373
self._experiment_consistent_rb(replay_buffer, rb_state)
74-
if self.keep_last:
74+
if fabric.is_global_zero and self.keep_last:
7575
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)
7676

7777
def on_checkpoint_trainer(
@@ -138,7 +138,7 @@ def _experiment_consistent_rb(
138138
# reinsert the open episodes to continue the training
139139
rb._open_episodes = state
140140

141-
def _delete_old_checkpoints(self, ckpt_folder: str | pathlib.Path):
141+
def _delete_old_checkpoints(self, ckpt_folder: pathlib.Path):
142142
ckpts = list(sorted(ckpt_folder.glob("*.ckpt"), key=os.path.getmtime))
143143
if len(ckpts) > self.keep_last:
144144
to_delete = ckpts[: -self.keep_last]

0 commit comments

Comments (0)