File tree Expand file tree Collapse file tree 3 files changed +10
-3
lines changed Expand file tree Collapse file tree 3 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -112,6 +112,7 @@ def __init__(
112112 accelerator = 'auto' ,
113113 checkpoint_every : int = 1000 ,
114114 checkpoint_folder : str = './checkpoints' ,
115+ overwrite_checkpoints : bool = False ,
115116 fabric_kwargs : dict = dict (),
116117 ema_kwargs : dict = dict ()
117118 ):
@@ -199,6 +200,7 @@ def __init__(
199200 # checkpointing logic
200201
201202 self .checkpoint_every = checkpoint_every
203+ self .overwrite_checkpoints = overwrite_checkpoints
202204 self .checkpoint_folder = Path (checkpoint_folder )
203205
204206 self .checkpoint_folder .mkdir (exist_ok = True , parents = True )
@@ -367,7 +369,9 @@ def __call__(
367369 self .wait ()
368370
369371 if self .is_main and divisible_by (self .steps , self .checkpoint_every ):
370- self .save (self .checkpoint_folder / f'af3.ckpt.{ self .steps } .pt' )
372+ checkpoint_path = self .checkpoint_folder / f'af3.ckpt.{ self .steps } .pt'
373+
374+ self .save (checkpoint_path , overwrite = self .overwrite_checkpoints )
371375
372376 self .wait ()
373377
Original file line number Diff line number Diff line change 11[project ]
22name = " alphafold3-pytorch"
3- version = " 0.0.61 "
3+ version = " 0.0.62 "
44description = " Alphafold 3 - Pytorch"
55authors = [
66 {
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change 11import os
22os .environ ['TYPECHECK' ] = 'True'
33
4+ from pathlib import Path
5+
46import pytest
57import torch
68from torch .utils .data import Dataset , DataLoader
@@ -133,7 +135,8 @@ def test_trainer():
133135 batch_size = 1 ,
134136 valid_every = 1 ,
135137 grad_accum_every = 2 ,
136- checkpoint_every = 1
138+ checkpoint_every = 1 ,
139+ overwrite_checkpoints = True
137140 )
138141
139142 trainer ()
You can’t perform that action at this time.
0 commit comments