Skip to content

Commit 56cabcc

Browse files
committed
able to set a custom prefix for checkpoints
1 parent 7c664d3 commit 56cabcc

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(
213213
default_lambda_lr = default_lambda_lr_fn,
214214
fabric: Fabric | None = None,
215215
accelerator = 'auto',
216+
checkpoint_prefix = 'af3.ckpt.',
216217
checkpoint_every: int = 1000,
217218
checkpoint_folder: str = './checkpoints',
218219
overwrite_checkpoints: bool = False,
@@ -309,6 +310,7 @@ def __init__(
309310

310311
# checkpointing logic
311312

313+
self.checkpoint_prefix = checkpoint_prefix
312314
self.checkpoint_every = checkpoint_every
313315
self.overwrite_checkpoints = overwrite_checkpoints
314316
self.checkpoint_folder = Path(checkpoint_folder)
@@ -326,7 +328,12 @@ def is_main(self):
326328

327329
# saving and loading
328330

329-
def save(self, path: str | Path, overwrite = False):
331+
def save(
332+
self,
333+
path: str | Path,
334+
overwrite = False,
335+
prefix: str | None = None
336+
):
330337
if isinstance(path, str):
331338
path = Path(path)
332339

@@ -343,7 +350,12 @@ def save(self, path: str | Path, overwrite = False):
343350

344351
torch.save(package, str(path))
345352

346-
def load(self, path: str | Path, strict = True):
353+
def load(
354+
self,
355+
path: str | Path,
356+
strict = True,
357+
prefix = None
358+
):
347359
if isinstance(path, str):
348360
path = Path(path)
349361

@@ -352,7 +364,9 @@ def load(self, path: str | Path, strict = True):
352364
# if the path is a directory, then automatically load latest checkpoint
353365

354366
if path.is_dir():
355-
model_paths = [*path.glob('**/*.pt')]
367+
prefix = default(prefix, self.checkpoint_prefix)
368+
369+
model_paths = [*path.glob(f'**/{prefix}*.pt')]
356370

357371
assert len(model_paths) > 0, f'no files found in directory {path}'
358372

@@ -500,7 +514,7 @@ def __call__(
500514
self.wait()
501515

502516
if self.is_main and divisible_by(self.steps, self.checkpoint_every):
503-
checkpoint_path = self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt'
517+
checkpoint_path = self.checkpoint_folder / f'{self.checkpoint_prefix}{self.steps}.pt'
504518

505519
self.save(checkpoint_path, overwrite = self.overwrite_checkpoints)
506520

0 commit comments

Comments
 (0)