Skip to content

Commit c461701

Browse files
committed
use registered pde_breaks, and some driveby cleanup
1 parent 54da357 commit c461701

File tree

4 files changed

+27
-30
lines changed

4 files changed

+27
-30
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def mean_pool_with_lens(
292292
def repeat_consecutive_with_lens(
293293
feats: Float['b n ...'] | Bool['b n ...'] | Bool['b n'] | Int['b n'],
294294
lens: Int['b n'],
295-
mask_value: Optional[float | int | bool] = None,
295+
mask_value: float | int | bool | None = None,
296296
) -> Float['b m ...'] | Bool['b m ...'] | Bool['b m'] | Int['b m']:
297297

298298
device, dtype = feats.device, feats.dtype
@@ -3227,12 +3227,11 @@ def compute_ptm(
32273227
def compute_pde(
32283228
self,
32293229
logits: Float['b pde n n'],
3230-
breaks: Float[' pde_break'],
32313230
tok_repr_atm_mask: Bool[' b n'],
32323231
)-> Float[' b n n']:
32333232

32343233
logits = rearrange(logits, 'b pde i j -> b i j pde')
3235-
bin_centers = self._calculate_bin_centers(breaks.to(logits.device))
3234+
bin_centers = self._calculate_bin_centers(self.pde_breaks)
32363235
probs = F.softmax(logits, dim=-1)
32373236

32383237
pde = einsum(probs, bin_centers, 'b i j pde, pde -> b i j ')
@@ -3495,7 +3494,7 @@ def compute_modified_residue_score(
34953494
# model selection
34963495
def get_cid_molecule_type(
34973496
cid: int,
3498-
asym_id: Int['n'],
3497+
asym_id: Int[' n'],
34993498
is_molecule_types: Bool['n {IS_MOLECULE_TYPES}'],
35003499
return_one_hot: bool = False,
35013500
) -> int | Bool[' {IS_MOLECULE_TYPES}']:
@@ -3505,17 +3504,15 @@ def get_cid_molecule_type(
35053504
"""
35063505

35073506
cid_is_molecule_types = is_molecule_types[asym_id == cid]
3508-
valid = torch.all(
3509-
einx.equal('b i, i -> b i',
3510-
cid_is_molecule_types,
3511-
cid_is_molecule_types[0])
3512-
)
3507+
molecule_type, rest_molecule_type = cid_is_molecule_types[0], cid_is_molecule_types[1:]
3508+
3509+
valid = einx.equal('b i, i -> b i', rest_molecule_type, molecule_type).all()
3510+
35133511
assert valid, f"Ambiguous molecule types for chain {cid}"
35143512

3515-
if return_one_hot:
3516-
molecule_type = cid_is_molecule_types[0]
3517-
else:
3518-
molecule_type = cid_is_molecule_types[0].int().argmax().item()
3513+
if not return_one_hot:
3514+
molecule_type = molecule_type.int().argmax().item()
3515+
35193516
return molecule_type
35203517

35213518
class ComputeModelSelectionScore(Module):
@@ -3525,14 +3522,17 @@ def __init__(
35253522
eps: float = 1e-8,
35263523
dist_breaks: Float[' dist_break'] = torch.linspace(2.3125,21.6875,63,),
35273524
nucleic_acid_cutoff: float = 30.0,
3528-
other_cutoff: float = 15.0
3525+
other_cutoff: float = 15.0,
3526+
contact_mask_threshold: float = 8.0
35293527
):
35303528

35313529
super().__init__()
35323530
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
35333531
self.eps = eps
35343532
self.nucleic_acid_cutoff = nucleic_acid_cutoff
35353533
self.other_cutoff = other_cutoff
3534+
self.contact_mask_threshold = contact_mask_threshold
3535+
35363536
self.register_buffer('dist_breaks', dist_breaks)
35373537

35383538
def compute_gpde(
@@ -3548,13 +3548,14 @@ def compute_gpde(
35483548
tok_repr_atm_mask: [b n] true if token representation atoms exists
35493549
"""
35503550

3551-
pde = self.compute_confidence_score.compute_pde(
3552-
pde_logits, self.compute_confidence_score.pde_breaks, tok_repr_atm_mask)
3551+
pde = self.compute_confidence_score.compute_pde(pde_logits, tok_repr_atm_mask)
35533552

35543553
dist_logits = rearrange(dist_logits, 'b dist i j -> b i j dist')
35553554
dist_probs = F.softmax(dist_logits, dim=-1)
3556-
contact_mask = dist_breaks < 8.0
3557-
contact_mask = torch.cat([contact_mask, torch.zeros([1], device=dist_logits.device)]).bool()
3555+
3556+
contact_mask = dist_breaks < self.contact_mask_threshold
3557+
contact_mask = F.pad(contact_mask, (0, 1), value = True)
3558+
35583559
contact_prob = einx.where(
35593560
' dist, b i j dist, -> b i j dist',
35603561
contact_mask, dist_probs, 0.
@@ -3577,7 +3578,7 @@ def compute_lddt(
35773578
is_rna: Bool['b m'],
35783579
pairwise_mask: Bool['b m m'],
35793580
coords_mask: Bool['b m'] | None = None,
3580-
) -> Float['b']:
3581+
) -> Float[' b']:
35813582
"""
35823583
pred_coords: predicted coordinates
35833584
true_coords: true coordinates
@@ -3635,7 +3636,7 @@ def compute_chain_pair_lddt(
36353636
true_coords: Float['b m 3'] | Float['m 3'],
36363637
is_molecule_types: Int['b m {IS_MOLECULE_TYPES}'] | Int['m {IS_MOLECULE_TYPES}'],
36373638
coords_mask: Bool['b m'] | Bool [' m'] | None = None,
3638-
) -> Float['b']:
3639+
) -> Float[' b']:
36393640
"""
36403641
36413642
plddt between atoms maked by asym_mask_a and asym_mask_b

alphafold3_pytorch/mocks.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
import torch
55
from torch.utils.data import Dataset
66
from alphafold3_pytorch import AtomInput
7-
8-
from alphafold3_pytorch.inputs import (
9-
IS_MOLECULE_TYPES,
10-
AtomInput
11-
)
7+
from alphafold3_pytorch.inputs import IS_MOLECULE_TYPES
128

139
# mock dataset
1410

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.50"
3+
version = "0.2.51"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import os
22
os.environ['TYPECHECK'] = 'True'
33

4-
import torch
54
import pytest
5+
import random
6+
import itertools
67
from pathlib import Path
78

9+
import torch
10+
811
from alphafold3_pytorch import (
912
SmoothLDDTLoss,
1013
WeightedRigidAlign,
@@ -989,9 +992,6 @@ def test_compute_ranking_score():
989992

990993
def test_model_selection_score():
991994

992-
import random
993-
import itertools
994-
995995
# mock inputs
996996

997997
batch_size = 2

0 commit comments

Comments
 (0)