Skip to content

Commit 92e052a

Browse files
committed
Try again to get tests to complete on GitHub
1 parent c5e8b20 commit 92e052a

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

tests/test_trainer.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,10 @@ def populate_mock_pdb_and_remove_test_folders():
226226
for i in range(100):
227227
shutil.copy2(str(working_cif_file), str(train_folder / f'{i}.cif'))
228228

229-
for i in range(4):
229+
for i in range(1):
230230
shutil.copy2(str(working_cif_file), str(valid_folder / f'{i}.cif'))
231231

232-
for i in range(2):
232+
for i in range(1):
233233
shutil.copy2(str(working_cif_file), str(test_folder / f'{i}.cif'))
234234

235235
yield
@@ -239,25 +239,41 @@ def populate_mock_pdb_and_remove_test_folders():
239239
def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
240240

241241
alphafold3 = Alphafold3(
242-
dim_atom=8,
243-
dim_atompair=8,
244-
dim_input_embedder_token=8,
245-
dim_single=8,
246-
dim_pairwise=8,
247-
dim_token=8,
242+
dim_atom=4,
243+
dim_atompair=4,
244+
dim_input_embedder_token=4,
245+
dim_single=4,
246+
dim_pairwise=4,
247+
dim_token=4,
248248
dim_atom_inputs=3,
249249
dim_atompair_inputs=1,
250250
atoms_per_window=27,
251251
dim_template_feats=44,
252252
num_dist_bins=38,
253-
confidence_head_kwargs=dict(pairformer_depth=1),
253+
confidence_head_kwargs=dict(
254+
pairformer_depth=1,
255+
),
254256
template_embedder_kwargs=dict(pairformer_stack_depth=1),
255257
msa_module_kwargs=dict(depth=1),
256-
pairformer_stack=dict(depth=1),
258+
pairformer_stack=dict(
259+
depth=1,
260+
pair_bias_attn_dim_head = 4,
261+
pair_bias_attn_heads = 2,
262+
),
257263
diffusion_module_kwargs=dict(
258264
atom_encoder_depth=1,
259265
token_transformer_depth=1,
260266
atom_decoder_depth=1,
267+
atom_decoder_kwargs = dict(
268+
attn_pair_bias_kwargs = dict(
269+
dim_head = 4
270+
)
271+
),
272+
atom_encoder_kwargs = dict(
273+
attn_pair_bias_kwargs = dict(
274+
dim_head = 4
275+
)
276+
)
261277
),
262278
)
263279

@@ -267,7 +283,7 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
267283

268284
# test saving and loading from Alphafold3, independent of lightning
269285

270-
dataloader = DataLoader(dataset, batch_size = 2)
286+
dataloader = DataLoader(dataset, batch_size = 1)
271287
inputs = next(iter(dataloader))
272288

273289
alphafold3.eval()
@@ -298,7 +314,7 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
298314
num_train_steps = 2,
299315
batch_size = 1,
300316
valid_every = 1,
301-
grad_accum_every = 2,
317+
grad_accum_every = 1,
302318
checkpoint_every = 1,
303319
checkpoint_folder = './test-folder/checkpoints',
304320
overwrite_checkpoints = True,

0 commit comments

Comments (0)