@@ -52,7 +52,7 @@ def on_checkpoint_coupled(
5252 fabric .save (ckpt_path , state )
5353 if replay_buffer is not None :
5454 self ._experiment_consistent_rb (replay_buffer , rb_state )
55- if self .keep_last :
55+ if fabric . is_global_zero and self .keep_last :
5656 self ._delete_old_checkpoints (pathlib .Path (ckpt_path ).parent )
5757
5858 def on_checkpoint_player (
@@ -71,7 +71,7 @@ def on_checkpoint_player(
7171 fabric .save (ckpt_path , state )
7272 if replay_buffer is not None :
7373 self ._experiment_consistent_rb (replay_buffer , rb_state )
74- if self .keep_last :
74+ if fabric . is_global_zero and self .keep_last :
7575 self ._delete_old_checkpoints (pathlib .Path (ckpt_path ).parent )
7676
7777 def on_checkpoint_trainer (
@@ -138,7 +138,7 @@ def _experiment_consistent_rb(
138138 # reinsert the open episodes to continue the training
139139 rb ._open_episodes = state
140140
141- def _delete_old_checkpoints (self , ckpt_folder : str | pathlib .Path ):
141+ def _delete_old_checkpoints (self , ckpt_folder : pathlib .Path ):
142142 ckpts = list (sorted (ckpt_folder .glob ("*.ckpt" ), key = os .path .getmtime ))
143143 if len (ckpts ) > self .keep_last :
144144 to_delete = ckpts [: - self .keep_last ]
0 commit comments