Skip to content

Commit 8d3a6ed

Browse files
authored
Update alphafold3.py (#278)
1 parent 2625953 commit 8d3a6ed

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7463,6 +7463,7 @@ def forward(
74637463
# determine which mask to use for confidence head labels
74647464

74657465
label_mask = atom_mask
7466+
label_pairwise_mask = to_pairwise_mask(mask)
74667467

74677468
# cross entropy losses
74687469

@@ -7498,14 +7499,14 @@ def cross_entropy_with_weight(
74987499
f"pae_labels shape {pae_labels.shape[-1]} does not match "
74997500
f"ch_logits.pae shape {ch_logits.pae.shape[-1]}"
75007501
)
7501-
pae_loss = cross_entropy_with_weight(ch_logits.pae, pae_labels, confidence_weight, pairwise_mask, ignore)
7502+
pae_loss = cross_entropy_with_weight(ch_logits.pae, pae_labels, confidence_weight, label_pairwise_mask, ignore)
75027503

75037504
if exists(pde_labels):
75047505
assert pde_labels.shape[-1] == ch_logits.pde.shape[-1], (
75057506
f"pde_labels shape {pde_labels.shape[-1]} does not match "
75067507
f"ch_logits.pde shape {ch_logits.pde.shape[-1]}"
75077508
)
7508-
pde_loss = cross_entropy_with_weight(ch_logits.pde, pde_labels, confidence_weight, pairwise_mask, ignore)
7509+
pde_loss = cross_entropy_with_weight(ch_logits.pde, pde_labels, confidence_weight, label_pairwise_mask, ignore)
75097510

75107511
if exists(plddt_labels):
75117512
assert plddt_labels.shape[-1] == ch_logits.plddt.shape[-1], (

0 commit comments

Comments
 (0)