File tree Expand file tree Collapse file tree 3 files changed +12
-2
lines changed Expand file tree Collapse file tree 3 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -354,7 +354,8 @@ def load(
354354 self ,
355355 path : str | Path ,
356356 strict = True ,
357- prefix = None
357+ prefix = None ,
358+ only_model = False
358359 ):
359360 if isinstance (path , str ):
360361 path = Path (path )
@@ -382,6 +383,11 @@ def load(
382383
383384 self .model .load (path )
384385
386+ if only_model :
387+ return
388+
389+ # load optimizer and scheduler states
390+
385391 package = torch .load (str (path ))
386392
387393 if 'optimizer' in package :
Original file line number Diff line number Diff line change 11[project ]
22name = " alphafold3-pytorch"
3- version = " 0.1.25 "
3+ version = " 0.1.26 "
44description = " Alphafold 3 - Pytorch"
55authors = [
66 {
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change @@ -164,6 +164,10 @@ def test_trainer():
164164 trainer .save ('./some/nested/folder2/training.pt' , overwrite = True )
165165 trainer .load ('./some/nested/folder2/training.pt' )
166166
167+ # allow for only loading model, needed for fine-tuning logic
168+
169+ trainer .load ('./some/nested/folder2/training.pt' , only_model = True )
170+
167171 # also allow for loading Alphafold3 directly from training ckpt
168172
169173 alphafold3 = Alphafold3 .init_and_load ('./some/nested/folder2/training.pt' )
You can’t perform that action at this time.
0 commit comments