Skip to content

Commit 2aa9ed3

Browse files
committed
move test related files under a test folder and remove it after testing
1 parent a424bc4 commit 2aa9ed3

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

tests/test_trainer.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
os.environ['TYPECHECK'] = 'True'
33

4+
import shutil
45
from pathlib import Path
56
from random import randrange, random
67
from dataclasses import asdict
@@ -100,7 +101,13 @@ def __getitem__(self, idx):
100101
resolved_labels = resolved_labels
101102
)
102103

103-
def test_trainer():
104+
@pytest.fixture()
105+
def remove_test_folders():
106+
yield
107+
shutil.rmtree('./test-folder')
108+
109+
def test_trainer(remove_test_folders):
110+
104111
alphafold3 = Alphafold3(
105112
dim_atom_inputs = 77,
106113
dim_template_feats = 44,
@@ -137,7 +144,7 @@ def test_trainer():
137144
_, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
138145
before_distogram = breakdown.distogram
139146

140-
path = './some/nested/folder/af3'
147+
path = './test-folder/nested/folder/af3'
141148
alphafold3.save(path, overwrite = True)
142149

143150
# load from scratch, along with saved hyperparameters
@@ -163,6 +170,7 @@ def test_trainer():
163170
valid_every = 1,
164171
grad_accum_every = 2,
165172
checkpoint_every = 1,
173+
checkpoint_folder = './test-folder/checkpoints',
166174
overwrite_checkpoints = True,
167175
ema_kwargs = dict(
168176
use_foreach = True,
@@ -175,26 +183,26 @@ def test_trainer():
175183

176184
# assert checkpoints created
177185

178-
assert Path(f'./checkpoints/({trainer.train_id})_af3.ckpt.1.pt').exists()
186+
assert Path(f'./test-folder/checkpoints/({trainer.train_id})_af3.ckpt.1.pt').exists()
179187

180188
# assert can load latest checkpoint by loading from a directory
181189

182-
trainer.load('./checkpoints', strict = False)
190+
trainer.load('./test-folder/checkpoints', strict = False)
183191

184192
assert exists(trainer.model_loaded_from_path)
185193

186194
# saving and loading from trainer
187195

188-
trainer.save('./some/nested/folder2/training.pt', overwrite = True)
189-
trainer.load('./some/nested/folder2/training.pt', strict = False)
196+
trainer.save('./test-folder/nested/folder2/training.pt', overwrite = True)
197+
trainer.load('./test-folder/nested/folder2/training.pt', strict = False)
190198

191199
# allow for only loading model, needed for fine-tuning logic
192200

193-
trainer.load('./some/nested/folder2/training.pt', only_model = True, strict = False)
201+
trainer.load('./test-folder/nested/folder2/training.pt', only_model = True, strict = False)
194202

195203
# also allow for loading Alphafold3 directly from training ckpt
196204

197-
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training.pt')
205+
alphafold3 = Alphafold3.init_and_load('./test-folder/nested/folder2/training.pt')
198206

199207
# test use of collation fn outside of trainer
200208

0 commit comments

Comments
 (0)