Skip to content

Commit d9e4fd4

Browse files
committed
make returning confidence head logits correct when sampling
1 parent 42ad45a commit d9e4fd4

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3639,19 +3639,17 @@ def forward(
36393639
if exists(atom_mask):
36403640
sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
36413641

3642+
if return_confidence_head_logits:
3643+
assert exists(molecule_atom_indices)
3644+
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', sampled_atom_pos, molecule_atom_indices)
3645+
36423646
if exists(missing_atom_mask) and return_present_sampled_atoms:
36433647
sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
36443648

36453649
if not return_confidence_head_logits:
36463650
return sampled_atom_pos
36473651

3648-
# todo - handle missing atoms
3649-
3650-
assert exists(molecule_atom_indices)
3651-
3652-
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', sampled_atom_pos, molecule_atom_indices)
3653-
3654-
logits = self.confidence_head(
3652+
confidence_head_logits = self.confidence_head(
36553653
single_repr = single.detach(),
36563654
single_inputs_repr = single_inputs.detach(),
36573655
pairwise_repr = pairwise.detach(),
@@ -3660,7 +3658,7 @@ def forward(
36603658
return_pae_logits = True
36613659
)
36623660

3663-
return sampled_atom_pos, logits
3661+
return sampled_atom_pos, confidence_head_logits
36643662

36653663
# if being forced to return loss, but do not have sufficient information to return losses, just return 0
36663664

@@ -3826,7 +3824,7 @@ def forward(
38263824

38273825
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, molecule_atom_indices)
38283826

3829-
logits = self.confidence_head(
3827+
ch_logits = self.confidence_head(
38303828
single_repr = single.detach(),
38313829
single_inputs_repr = single_inputs.detach(),
38323830
pairwise_repr = pairwise.detach(),
@@ -3837,19 +3835,19 @@ def forward(
38373835

38383836
if exists(pae_labels):
38393837
pae_labels = torch.where(pairwise_mask, pae_labels, ignore)
3840-
pae_loss = F.cross_entropy(logits.pae, pae_labels, ignore_index = ignore)
3838+
pae_loss = F.cross_entropy(ch_logits.pae, pae_labels, ignore_index = ignore)
38413839

38423840
if exists(pde_labels):
38433841
pde_labels = torch.where(pairwise_mask, pde_labels, ignore)
3844-
pde_loss = F.cross_entropy(logits.pde, pde_labels, ignore_index = ignore)
3842+
pde_loss = F.cross_entropy(ch_logits.pde, pde_labels, ignore_index = ignore)
38453843

38463844
if exists(plddt_labels):
38473845
plddt_labels = torch.where(mask, plddt_labels, ignore)
3848-
plddt_loss = F.cross_entropy(logits.plddt, plddt_labels, ignore_index = ignore)
3846+
plddt_loss = F.cross_entropy(ch_logits.plddt, plddt_labels, ignore_index = ignore)
38493847

38503848
if exists(resolved_labels):
38513849
resolved_labels = torch.where(mask, resolved_labels, ignore)
3852-
resolved_loss = F.cross_entropy(logits.resolved, resolved_labels, ignore_index = ignore)
3850+
resolved_loss = F.cross_entropy(ch_logits.resolved, resolved_labels, ignore_index = ignore)
38533851

38543852
confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss
38553853

0 commit comments

Comments
 (0)