Skip to content

Commit 3598f6e

Browse files
authored
fix tests
1 parent f14b816 commit 3598f6e

File tree

5 files changed

+64
-19
lines changed

5 files changed

+64
-19
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ jobs:
66

77
runs-on: ubuntu-latest
88
timeout-minutes: 20
9+
strategy:
10+
fail-fast: false
911

1012
steps:
1113
- uses: actions/checkout@v4

alphafold3_pytorch/trainer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def __init__(
267267
checkpoint_folder: str = './checkpoints',
268268
overwrite_checkpoints: bool = False,
269269
fabric_kwargs: dict = dict(),
270+
use_ema: bool = True,
270271
ema_kwargs: dict = dict(
271272
use_foreach = True
272273
)
@@ -285,7 +286,10 @@ def __init__(
285286

286287
# exponential moving average
287288

288-
if self.is_main:
289+
self.ema_model = None
290+
self.has_ema = self.is_main and use_ema
291+
292+
if self.has_ema:
289293
self.ema_model = EMA(
290294
model,
291295
beta = ema_decay,
@@ -574,7 +578,7 @@ def __call__(
574578

575579
self.wait()
576580

577-
if self.is_main:
581+
if self.has_ema:
578582
self.ema_model.update()
579583

580584
self.wait()
@@ -593,14 +597,16 @@ def __call__(
593597
self.needs_valid and
594598
divisible_by(self.steps, self.valid_every)
595599
):
600+
eval_model = default(self.ema_model, self.model)
601+
596602
with torch.no_grad():
597-
self.ema_model.eval()
603+
eval_model.eval()
598604

599605
total_valid_loss = 0.
600606
valid_loss_breakdown = None
601607

602608
for valid_batch in self.valid_dataloader:
603-
valid_loss, loss_breakdown = self.ema_model(
609+
valid_loss, loss_breakdown = eval_model(
604610
**valid_batch.dict(),
605611
return_loss_breakdown = True
606612
)
@@ -631,14 +637,16 @@ def __call__(
631637
# maybe test
632638

633639
if self.is_main and self.needs_test:
640+
eval_model = default(self.ema_model, self.model)
641+
634642
with torch.no_grad():
635-
self.ema_model.eval()
643+
eval_model.eval()
636644

637645
total_test_loss = 0.
638646
test_loss_breakdown = None
639647

640648
for test_batch in self.test_dataloader:
641-
test_loss, loss_breakdown = self.ema_model(
649+
test_loss, loss_breakdown = eval_model(
642650
**test_batch.dict(),
643651
return_loss_breakdown = True
644652
)

tests/test_af3.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def test_alphafold3(
504504
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
505505
template_mask = torch.ones((2, 2)).bool()
506506

507-
msa = torch.randn(2, 7, seq_len, 64)
507+
msa = torch.randn(2, 7, seq_len, 8)
508508
msa_mask = torch.ones((2, 7)).bool()
509509

510510
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -519,7 +519,9 @@ def test_alphafold3(
519519

520520
alphafold3 = Alphafold3(
521521
dim_atom_inputs = 77,
522-
dim_pairwise = 64,
522+
dim_pairwise = 8,
523+
dim_single = 8,
524+
dim_token = 8,
523525
atoms_per_window = atoms_per_window,
524526
dim_template_feats = 44,
525527
num_dist_bins = 38,
@@ -531,15 +533,28 @@ def test_alphafold3(
531533
pairformer_stack_depth = 1
532534
),
533535
msa_module_kwargs = dict(
534-
depth = 1
536+
depth = 1,
537+
dim_msa = 8,
535538
),
536-
pairformer_stack = dict(
537-
depth = 2
539+
pairformer_stack=dict(
540+
depth=1,
541+
pair_bias_attn_dim_head = 4,
542+
pair_bias_attn_heads = 2,
538543
),
539-
diffusion_module_kwargs = dict(
540-
atom_encoder_depth = 1,
541-
token_transformer_depth = 1,
542-
atom_decoder_depth = 1,
544+
diffusion_module_kwargs=dict(
545+
atom_encoder_depth=1,
546+
token_transformer_depth=1,
547+
atom_decoder_depth=1,
548+
atom_decoder_kwargs = dict(
549+
attn_pair_bias_kwargs = dict(
550+
dim_head = 4
551+
)
552+
),
553+
atom_encoder_kwargs = dict(
554+
attn_pair_bias_kwargs = dict(
555+
dim_head = 4
556+
)
557+
)
543558
),
544559
stochastic_frame_average = stochastic_frame_average,
545560
confidence_head_atom_resolution = confidence_head_atom_resolution
@@ -569,6 +584,7 @@ def test_alphafold3(
569584
pde_labels = pde_labels,
570585
plddt_labels = plddt_labels,
571586
resolved_labels = resolved_labels,
587+
num_rollout_steps = 1,
572588
diffusion_add_smooth_lddt_loss = True,
573589
return_loss_breakdown = True
574590
)

tests/test_input.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,25 @@ def test_pdbinput_input():
174174
confidence_head_kwargs=dict(pairformer_depth=1),
175175
template_embedder_kwargs=dict(pairformer_stack_depth=1),
176176
msa_module_kwargs=dict(depth=1),
177-
pairformer_stack=dict(depth=1),
177+
pairformer_stack=dict(
178+
depth=1,
179+
pair_bias_attn_dim_head = 4,
180+
pair_bias_attn_heads = 2,
181+
),
178182
diffusion_module_kwargs=dict(
179183
atom_encoder_depth=1,
180184
token_transformer_depth=1,
181185
atom_decoder_depth=1,
186+
atom_decoder_kwargs = dict(
187+
attn_pair_bias_kwargs = dict(
188+
dim_head = 4
189+
)
190+
),
191+
atom_encoder_kwargs = dict(
192+
attn_pair_bias_kwargs = dict(
193+
dim_head = 4
194+
)
195+
)
182196
),
183197
)
184198

tests/test_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def populate_mock_pdb_and_remove_test_folders():
223223
valid_folder.mkdir(exist_ok = True, parents = True)
224224
test_folder.mkdir(exist_ok = True, parents = True)
225225

226-
for i in range(100):
226+
for i in range(2):
227227
shutil.copy2(str(working_cif_file), str(train_folder / f'{i}.cif'))
228228

229229
for i in range(1):
@@ -287,7 +287,9 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
287287
inputs = next(iter(dataloader))
288288

289289
alphafold3.eval()
290-
_, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
290+
with torch.no_grad():
291+
_, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
292+
291293
before_distogram = breakdown.distogram
292294

293295
path = './test-folder/nested/folder/af3'
@@ -298,7 +300,9 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
298300
alphafold3 = Alphafold3.init_and_load(path)
299301

300302
alphafold3.eval()
301-
_, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
303+
with torch.no_grad():
304+
_, breakdown = alphafold3(**asdict(inputs), return_loss_breakdown = True)
305+
302306
after_distogram = breakdown.distogram
303307

304308
assert torch.allclose(before_distogram, after_distogram)
@@ -318,6 +322,7 @@ def test_trainer_with_pdb_input(populate_mock_pdb_and_remove_test_folders):
318322
checkpoint_every = 1,
319323
checkpoint_folder = './test-folder/checkpoints',
320324
overwrite_checkpoints = True,
325+
use_ema = False,
321326
ema_kwargs = dict(
322327
use_foreach = True,
323328
update_after_step = 0,

0 commit comments

Comments
 (0)