File tree Expand file tree Collapse file tree 3 files changed +20
-2
lines changed Expand file tree Collapse file tree 3 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -110,6 +110,8 @@ def __init__(
110110 default_lambda_lr = default_lambda_lr_fn ,
111111 fabric : Fabric | None = None ,
112112 accelerator = 'auto' ,
113+ checkpoint_every : int = 1000 ,
114+ checkpoint_folder : str = './checkpoints' ,
113115 fabric_kwargs : dict = dict (),
114116 ema_kwargs : dict = dict ()
115117 ):
@@ -194,6 +196,14 @@ def __init__(
194196
195197 self .steps = 0
196198
199+ # checkpointing logic
200+
201+ self .checkpoint_every = checkpoint_every
202+ self .checkpoint_folder = Path (checkpoint_folder )
203+
204+ self .checkpoint_folder .mkdir (exist_ok = True , parents = True )
205+ assert self .checkpoint_folder .is_dir ()
206+
197207 @property
198208 def is_main (self ):
199209 return self .fabric .global_rank == 0
@@ -356,6 +366,11 @@ def __call__(
356366
357367 self .wait ()
358368
369+ if self .is_main and divisible_by (self .steps , self .checkpoint_every ):
370+ self .save (self .checkpoint_folder / f'af3.ckpt.{ self .steps } .pt' )
371+
372+ self .wait ()
373+
359374 # maybe test
360375
361376 if self .is_main and self .needs_test :
Original file line number Diff line number Diff line change 11[project ]
22name = " alphafold3-pytorch"
3- version = " 0.0.60 "
3+ version = " 0.0.61 "
44description = " Alphafold 3 - Pytorch"
55authors = [
66 {
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change @@ -132,11 +132,14 @@ def test_trainer():
132132 num_train_steps = 2 ,
133133 batch_size = 1 ,
134134 valid_every = 1 ,
135- grad_accum_every = 2
135+ grad_accum_every = 2 ,
136+ checkpoint_every = 1
136137 )
137138
138139 trainer ()
139140
141+ assert Path ('./checkpoints/af3.ckpt.1.pt' ).exists ()
142+
140143 # saving and loading from trainer
141144
142145 trainer .save ('./some/nested/folder2/training' , overwrite = True )
You can’t perform that action at this time.
0 commit comments