
Commit d66c56c

able to force sampling or return of loss, no matter the state of the labels being passed in
1 parent 4d4cfce · commit d66c56c

File tree

3 files changed: +92 −15 lines

  alphafold3_pytorch/alphafold3.py
  pyproject.toml
  tests/test_af3.py

alphafold3_pytorch/alphafold3.py (17 additions, 3 deletions)

@@ -3161,7 +3161,7 @@ def forward(
         plddt_labels: Int['b n'] | None = None,
         resolved_labels: Int['b n'] | None = None,
         return_loss_breakdown = False,
-        return_loss_if_possible: bool = True,
+        return_loss: bool = None,
         num_rollout_steps: int = 20,
         rollout_show_tqdm_pbar: bool = False
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
@@ -3317,11 +3317,15 @@ def forward(

         has_labels = any([*map(exists, all_labels)])

-        return_loss = atom_pos_given or has_labels
+        can_return_loss = atom_pos_given or has_labels
+
+        # default whether to return loss by whether labels or atom positions are given
+
+        return_loss = default(return_loss, can_return_loss)

         # if neither atom positions or any labels are passed in, sample a structure and return

-        if not return_loss_if_possible or not return_loss:
+        if not return_loss:
             return self.edm.sample(
                 num_sample_steps = num_sample_steps,
                 atom_feats = atom_feats,
@@ -3335,6 +3339,16 @@ def forward(
                 molecule_atom_lens = molecule_atom_lens
             )

+        # if being forced to return loss, but do not have sufficient information to return losses, just return 0
+
+        if return_loss and not can_return_loss:
+            zero = self.zero.requires_grad_()
+
+            if not return_loss_breakdown:
+                return zero
+
+            return zero, LossBreakdown(*((zero,) * 11))
+
         # losses default to 0

         distogram_loss = diffusion_loss = confidence_loss = pae_loss = pde_loss = plddt_loss = resolved_loss = self.zero
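
For context, the gating logic the hunks above introduce can be reduced to a few lines. The sketch below is not the library's forward method, just a standalone illustration of the new tri-state return_loss flag; resolve_return_loss is a hypothetical helper named here for illustration, and default mirrors the repository's small utility that falls back to a value when the argument is None.

from typing import Optional, Tuple

def default(value, fallback):
    # fall back only when the argument was not explicitly set
    return value if value is not None else fallback

def resolve_return_loss(
    return_loss: Optional[bool],   # new tri-state flag: None, True, or False
    atom_pos_given: bool,
    has_labels: bool
) -> Tuple[bool, bool]:
    # a loss can only be computed when ground-truth positions or any labels exist
    can_return_loss = atom_pos_given or has_labels

    # None keeps the old behavior (decide from the inputs); True / False overrides it
    return default(return_loss, can_return_loss), can_return_loss

# force sampling even though labels are available
assert resolve_return_loss(False, atom_pos_given = True, has_labels = True) == (False, True)

# force a loss with nothing to supervise against
assert resolve_return_loss(True, atom_pos_given = False, has_labels = False) == (True, False)

# leave it as None: same behavior as before the commit
assert resolve_return_loss(None, atom_pos_given = False, has_labels = True) == (True, True)

In the actual forward pass shown above, a False result routes to self.edm.sample(...), while a True result with can_return_loss being False returns the zero tensor (and, when requested, an all-zero LossBreakdown).
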

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.41"
+version = "0.1.42"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py (74 additions, 11 deletions)

@@ -37,9 +37,6 @@
     atom_ref_pos_to_atompair_inputs
 )

-def join(str, delimiter = ','):
-    return delimiter.join(str)
-
 def test_atom_ref_pos_to_atompair_inputs():
     atom_ref_pos = torch.randn(16, 3)
     atom_ref_space_uid = torch.ones(16).long()
@@ -409,14 +406,8 @@ def test_distogram_head():

     logits = distogram_head(pairwise_repr)

-@pytest.mark.parametrize(
-    join([
-        'window_atompair_inputs',
-        'stochastic_frame_average'
-    ]), [
-        (True, False),
-        (True, False)
-    ])
+@pytest.mark.parametrize('window_atompair_inputs', (True, False))
+@pytest.mark.parametrize('stochastic_frame_average', (True, False))
 def test_alphafold3(
     window_atompair_inputs: bool,
     stochastic_frame_average: bool
@@ -572,6 +563,78 @@ def test_alphafold3_without_msa_and_templates():

     loss.backward()

+def test_alphafold3_force_return_loss():
+    seq_len = 16
+    molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
+    atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
+
+    atom_inputs = torch.randn(2, atom_seq_len, 77)
+    atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
+    additional_molecule_feats = torch.randn(2, seq_len, 10)
+
+    atom_pos = torch.randn(2, atom_seq_len, 3)
+    molecule_atom_indices = molecule_atom_lens - 1
+
+    distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
+    pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
+    pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
+    plddt_labels = torch.randint(0, 50, (2, seq_len))
+    resolved_labels = torch.randint(0, 2, (2, seq_len))
+
+    alphafold3 = Alphafold3(
+        dim_atom_inputs = 77,
+        dim_template_feats = 44,
+        num_dist_bins = 38,
+        confidence_head_kwargs = dict(
+            pairformer_depth = 1
+        ),
+        template_embedder_kwargs = dict(
+            pairformer_stack_depth = 1
+        ),
+        msa_module_kwargs = dict(
+            depth = 1
+        ),
+        pairformer_stack = dict(
+            depth = 2
+        ),
+        diffusion_module_kwargs = dict(
+            atom_encoder_depth = 1,
+            token_transformer_depth = 1,
+            atom_decoder_depth = 1,
+        ),
+    )
+
+    sampled_atom_pos = alphafold3(
+        num_recycling_steps = 2,
+        atom_inputs = atom_inputs,
+        molecule_atom_lens = molecule_atom_lens,
+        atompair_inputs = atompair_inputs,
+        additional_molecule_feats = additional_molecule_feats,
+        atom_pos = atom_pos,
+        molecule_atom_indices = molecule_atom_indices,
+        distance_labels = distance_labels,
+        pae_labels = pae_labels,
+        pde_labels = pde_labels,
+        plddt_labels = plddt_labels,
+        resolved_labels = resolved_labels,
+        return_loss_breakdown = True,
+        return_loss = False # force sampling even if labels are given
+    )
+
+    assert sampled_atom_pos.ndim == 3
+
+    loss, _ = alphafold3(
+        num_recycling_steps = 2,
+        atom_inputs = atom_inputs,
+        molecule_atom_lens = molecule_atom_lens,
+        atompair_inputs = atompair_inputs,
+        additional_molecule_feats = additional_molecule_feats,
+        return_loss_breakdown = True,
+        return_loss = True # force returning loss even if no labels given
+    )
+
+    assert loss == 0.
+
 # test creation from config

 def test_alphafold3_config():
