@@ -213,6 +213,7 @@ def __init__(
213213 default_lambda_lr = default_lambda_lr_fn ,
214214 fabric : Fabric | None = None ,
215215 accelerator = 'auto' ,
216+ checkpoint_prefix = 'af3.ckpt.' ,
216217 checkpoint_every : int = 1000 ,
217218 checkpoint_folder : str = './checkpoints' ,
218219 overwrite_checkpoints : bool = False ,
@@ -309,6 +310,7 @@ def __init__(
309310
310311 # checkpointing logic
311312
313+ self .checkpoint_prefix = checkpoint_prefix
312314 self .checkpoint_every = checkpoint_every
313315 self .overwrite_checkpoints = overwrite_checkpoints
314316 self .checkpoint_folder = Path (checkpoint_folder )
@@ -326,7 +328,12 @@ def is_main(self):
326328
327329 # saving and loading
328330
329- def save (self , path : str | Path , overwrite = False ):
331+ def save (
332+ self ,
333+ path : str | Path ,
334+ overwrite = False ,
335+ prefix : str | None = None
336+ ):
330337 if isinstance (path , str ):
331338 path = Path (path )
332339
@@ -343,7 +350,12 @@ def save(self, path: str | Path, overwrite = False):
343350
344351 torch .save (package , str (path ))
345352
346- def load (self , path : str | Path , strict = True ):
353+ def load (
354+ self ,
355+ path : str | Path ,
356+ strict = True ,
357+ prefix = None
358+ ):
347359 if isinstance (path , str ):
348360 path = Path (path )
349361
@@ -352,7 +364,9 @@ def load(self, path: str | Path, strict = True):
352364 # if the path is a directory, then automatically load latest checkpoint
353365
354366 if path .is_dir ():
355- model_paths = [* path .glob ('**/*.pt' )]
367+ prefix = default (prefix , self .checkpoint_prefix )
368+
369+ model_paths = [* path .glob (f'**/{ prefix } *.pt' )]
356370
357371 assert len (model_paths ) > 0 , f'no files found in directory { path } '
358372
@@ -500,7 +514,7 @@ def __call__(
500514 self .wait ()
501515
502516 if self .is_main and divisible_by (self .steps , self .checkpoint_every ):
503- checkpoint_path = self .checkpoint_folder / f'af3.ckpt. { self .steps } .pt'
517+ checkpoint_path = self .checkpoint_folder / f'{ self . checkpoint_prefix } { self .steps } .pt'
504518
505519 self .save (checkpoint_path , overwrite = self .overwrite_checkpoints )
506520
0 commit comments