11import os
22os .environ ['TYPECHECK' ] = 'True'
33
4+ import shutil
45from pathlib import Path
56from random import randrange , random
67from dataclasses import asdict
@@ -100,7 +101,13 @@ def __getitem__(self, idx):
100101 resolved_labels = resolved_labels
101102 )
102103
103- def test_trainer ():
104+ @pytest .fixture ()
105+ def remove_test_folders ():
106+ yield
107+ shutil .rmtree ('./test-folder' )
108+
109+ def test_trainer (remove_test_folders ):
110+
104111 alphafold3 = Alphafold3 (
105112 dim_atom_inputs = 77 ,
106113 dim_template_feats = 44 ,
@@ -137,7 +144,7 @@ def test_trainer():
137144 _ , breakdown = alphafold3 (** asdict (inputs ), return_loss_breakdown = True )
138145 before_distogram = breakdown .distogram
139146
140- path = './some /nested/folder/af3'
147+ path = './test-folder /nested/folder/af3'
141148 alphafold3 .save (path , overwrite = True )
142149
143150 # load from scratch, along with saved hyperparameters
@@ -163,6 +170,7 @@ def test_trainer():
163170 valid_every = 1 ,
164171 grad_accum_every = 2 ,
165172 checkpoint_every = 1 ,
173+ checkpoint_folder = './test-folder/checkpoints' ,
166174 overwrite_checkpoints = True ,
167175 ema_kwargs = dict (
168176 use_foreach = True ,
@@ -175,26 +183,26 @@ def test_trainer():
175183
176184 # assert checkpoints created
177185
178- assert Path (f'./checkpoints/({ trainer .train_id } )_af3.ckpt.1.pt' ).exists ()
186+ assert Path (f'./test-folder/ checkpoints/({ trainer .train_id } )_af3.ckpt.1.pt' ).exists ()
179187
180188 # assert can load latest checkpoint by loading from a directory
181189
182- trainer .load ('./checkpoints' , strict = False )
190+ trainer .load ('./test-folder/ checkpoints' , strict = False )
183191
184192 assert exists (trainer .model_loaded_from_path )
185193
186194 # saving and loading from trainer
187195
188- trainer .save ('./some /nested/folder2/training.pt' , overwrite = True )
189- trainer .load ('./some /nested/folder2/training.pt' , strict = False )
196+ trainer .save ('./test-folder /nested/folder2/training.pt' , overwrite = True )
197+ trainer .load ('./test-folder /nested/folder2/training.pt' , strict = False )
190198
191199 # allow for only loading model, needed for fine-tuning logic
192200
193- trainer .load ('./some /nested/folder2/training.pt' , only_model = True , strict = False )
201+ trainer .load ('./test-folder /nested/folder2/training.pt' , only_model = True , strict = False )
194202
195203 # also allow for loading Alphafold3 directly from training ckpt
196204
197- alphafold3 = Alphafold3 .init_and_load ('./some /nested/folder2/training.pt' )
205+ alphafold3 = Alphafold3 .init_and_load ('./test-folder /nested/folder2/training.pt' )
198206
199207# test use of collation fn outside of trainer
200208
0 commit comments