Skip to content

Commit 6300cbe

Browse files
committed
complete migration of PDE to atom resolution
1 parent e54c050 commit 6300cbe

File tree

3 files changed: 27 additions and 41 deletions.

alphafold3_pytorch/alphafold3.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3428,7 +3428,7 @@ def forward(
34283428

34293429
class ConfidenceHeadLogits(NamedTuple):
34303430
pae: Float['b pae m m'] | None
3431-
pde: Float['b pde n n']
3431+
pde: Float['b pde m m']
34323432
plddt: Float['b plddt m']
34333433
resolved: Float['b 2 m']
34343434

@@ -3567,7 +3567,7 @@ def forward(
35673567

35683568
# to logits
35693569

3570-
pde_logits = self.to_pde_logits(symmetrize(pairwise_repr))
3570+
pde_logits = self.to_pde_logits(symmetrize(atom_pairwise_repr))
35713571

35723572
plddt_logits = self.to_plddt_logits(atom_single_repr)
35733573
resolved_logits = self.to_resolved_logits(atom_single_repr)
@@ -4344,7 +4344,7 @@ def can_calculate_unresolved_protein_rasa(self):
43444344
@typecheck
43454345
def compute_gpde(
43464346
self,
4347-
pde_logits: Float["b pde n n"],
4347+
pde_logits: Float["b pde n n"],
43484348
dist_logits: Float["b dist n n"],
43494349
dist_breaks: Float[" dist_break"],
43504350
tok_repr_atm_mask: Bool["b n"],
@@ -4852,6 +4852,7 @@ def compute_model_selection_score(
48524852
top_ranked_sample = max(
48534853
scored_samples, key=lambda x: x[-1].mean()
48544854
) # rank by batch-averaged gPDE
4855+
48554856
best_of_5_sample = max(
48564857
scored_samples, key=lambda x: x[-2].mean()
48574858
) # rank by batch-averaged lDDT
@@ -5898,29 +5899,11 @@ def forward(
58985899
pde_labels = None
58995900

59005901
if atom_pos_given:
5901-
denoised_molecule_pos = None
5902-
5903-
assert exists(
5904-
molecule_atom_indices
5905-
), "`molecule_atom_indices` must be passed in for calculating non-atomic PDE labels"
5906-
5907-
# molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, molecule_atom_indices)
5908-
5909-
mol_atom_indices = repeat(
5910-
molecule_atom_indices, "b n -> b n c", c=atom_pos.shape[-1]
5911-
)
59125902

5913-
molecule_pos = atom_pos.gather(1, mol_atom_indices)
5914-
denoised_molecule_pos = denoised_atom_pos.gather(1, mol_atom_indices)
5903+
pde_atom_mask = batch_repeat_interleave(valid_molecule_atom_mask, molecule_atom_lens)
59155904

5916-
molecule_mask = valid_molecule_atom_mask
5917-
5918-
pde_gt_dist = torch.cdist(molecule_pos, molecule_pos, p=2)
5919-
pde_pred_dist = torch.cdist(
5920-
denoised_molecule_pos,
5921-
denoised_molecule_pos,
5922-
p=2,
5923-
)
5905+
pde_gt_dist = torch.cdist(atom_pos, atom_pos)
5906+
pde_pred_dist = torch.cdist(denoised_atom_pos, denoised_atom_pos)
59245907

59255908
# calculate pde labels as distance error binned to 64 (0 - 32A)
59265909

@@ -5929,8 +5912,8 @@ def forward(
59295912

59305913
# account for representative molecule atom missing from residue (-1 set on molecule_atom_indices field)
59315914

5932-
molecule_mask = to_pairwise_mask(molecule_mask)
5933-
pde_labels.masked_fill_(~molecule_mask, ignore)
5915+
pde_pairwise_atom_mask = to_pairwise_mask(pde_atom_mask)
5916+
pde_labels.masked_fill_(~pde_pairwise_atom_mask, ignore)
59345917

59355918
# determine plddt labels if possible
59365919

@@ -6016,7 +5999,17 @@ def forward(
60165999

60176000
confidence_weight = confidence_mask.float()
60186001

6019-
def cross_entropy_with_weight(logits, labels, weight, ignore_index: int):
6002+
@typecheck
6003+
def cross_entropy_with_weight(
6004+
logits: Float['b l ...'],
6005+
labels: Int['b ...'],
6006+
weight: Float[' b'],
6007+
mask: Bool['b ...'],
6008+
ignore_index: int
6009+
) -> Float['']:
6010+
6011+
labels = torch.where(mask, labels, ignore_index)
6012+
60206013
return F.cross_entropy(
60216014
einx.multiply('b ..., b -> b ...', logits, weight),
60226015
einx.multiply('b ..., b -> b ...', labels, weight.long()),
@@ -6028,32 +6021,28 @@ def cross_entropy_with_weight(logits, labels, weight, ignore_index: int):
60286021
f"pae_labels shape {pae_labels.shape[-1]} does not match "
60296022
f"ch_logits.pae shape {ch_logits.pae.shape[-1]}"
60306023
)
6031-
pae_labels = torch.where(label_pairwise_mask, pae_labels, ignore)
6032-
pae_loss = cross_entropy_with_weight(ch_logits.pae, pae_labels, confidence_weight, ignore)
6024+
pae_loss = cross_entropy_with_weight(ch_logits.pae, pae_labels, confidence_weight, label_pairwise_mask, ignore)
60336025

60346026
if exists(pde_labels):
60356027
assert pde_labels.shape[-1] == ch_logits.pde.shape[-1], (
60366028
f"pde_labels shape {pde_labels.shape[-1]} does not match "
60376029
f"ch_logits.pde shape {ch_logits.pde.shape[-1]}"
60386030
)
6039-
pde_labels = torch.where(to_pairwise_mask(mask), pde_labels, ignore)
6040-
pde_loss = cross_entropy_with_weight(ch_logits.pde, pde_labels, confidence_weight, ignore)
6031+
pde_loss = cross_entropy_with_weight(ch_logits.pde, pde_labels, confidence_weight, label_pairwise_mask, ignore)
60416032

60426033
if exists(plddt_labels):
60436034
assert plddt_labels.shape[-1] == ch_logits.plddt.shape[-1], (
60446035
f"plddt_labels shape {plddt_labels.shape[-1]} does not match "
60456036
f"ch_logits.plddt shape {ch_logits.plddt.shape[-1]}"
60466037
)
6047-
plddt_labels = torch.where(label_mask, plddt_labels, ignore)
6048-
plddt_loss = cross_entropy_with_weight(ch_logits.plddt, plddt_labels, confidence_weight, ignore)
6038+
plddt_loss = cross_entropy_with_weight(ch_logits.plddt, plddt_labels, confidence_weight, label_mask, ignore)
60496039

60506040
if exists(resolved_labels):
60516041
assert resolved_labels.shape[-1] == ch_logits.resolved.shape[-1], (
60526042
f"resolved_labels shape {resolved_labels.shape[-1]} does not match "
60536043
f"ch_logits.resolved shape {ch_logits.resolved.shape[-1]}"
60546044
)
6055-
resolved_labels = torch.where(label_mask, resolved_labels, ignore)
6056-
resolved_loss = cross_entropy_with_weight(ch_logits.resolved, resolved_labels, confidence_weight, ignore)
6045+
resolved_loss = cross_entropy_with_weight(ch_logits.resolved, resolved_labels, confidence_weight, label_mask, ignore)
60576046

60586047
confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss
60596048

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.14"
3+
version = "0.3.15"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def test_confidence_head():
481481
)
482482

483483
assert logits.pae.shape[-1] == atom_seq_len
484-
assert logits.pde.shape[-1] == seq_len
484+
assert logits.pde.shape[-1] == atom_seq_len
485485

486486
assert logits.plddt.shape[-1] == atom_seq_len
487487
assert logits.resolved.shape[-1] == atom_seq_len
@@ -1000,7 +1000,7 @@ def test_compute_ranking_score():
10001000
is_modified_residue = torch.randint(0, 2, (batch_size, atom_seq_len))
10011001

10021002
pae_logits = torch.randn(batch_size, 64, atom_seq_len, atom_seq_len)
1003-
pde_logits = torch.randn(batch_size, 64, seq_len, seq_len)
1003+
pde_logits = torch.randn(batch_size, 64, atom_seq_len, atom_seq_len)
10041004
plddt_logits = torch.randn(batch_size, 50, atom_seq_len)
10051005
resolved_logits = torch.randint(0, 2, (batch_size, 2, atom_seq_len))
10061006

@@ -1043,7 +1043,6 @@ def test_compute_ranking_score():
10431043
assert atom_level_ptm_score.numel() == batch_size
10441044

10451045
def test_model_selection_score():
1046-
10471046
# mock inputs
10481047

10491048
batch_size = 2
@@ -1062,13 +1061,11 @@ def test_model_selection_score():
10621061

10631062
chain_length = [random.randint(seq_len // 4, seq_len //2)
10641063
for _ in range(batch_size)]
1065-
10661064
asym_id = torch.tensor([
10671065
[item for val, count in enumerate([chain_len, seq_len - chain_len]) for item in itertools.repeat(val, count)]
10681066
for chain_len in chain_length
10691067
]).long()
10701068

1071-
10721069
is_molecule_types = torch.zeros_like(asym_id)
10731070
is_molecule_types = torch.nn.functional.one_hot(is_molecule_types, 5).bool()
10741071

0 commit comments

Comments
 (0)