Skip to content

Commit 771cc4c

Browse files
committed
always atom resolution for trunk distogram
1 parent c357919 commit 771cc4c

File tree

4 files changed

+16
-35
lines changed

4 files changed

+16
-35
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3392,7 +3392,6 @@ def __init__(
33923392
dim_pairwise = 128,
33933393
num_dist_bins = 38,
33943394
dim_atom = 128,
3395-
atom_resolution = False
33963395
):
33973396
super().__init__()
33983397

@@ -3404,10 +3403,7 @@ def __init__(
34043403
# atom resolution
34053404
# for now, just embed per atom distances, sum to atom features, project to pairwise dimension
34063405

3407-
self.atom_resolution = atom_resolution
3408-
3409-
if atom_resolution:
3410-
self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
3406+
self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
34113407

34123408
# tensor typing
34133409

@@ -3417,16 +3413,12 @@ def __init__(
34173413
def forward(
34183414
self,
34193415
pairwise_repr: Float['b n n d'],
3420-
molecule_atom_lens: Int['b n'] | None = None,
3421-
atom_feats: Float['b m {self.da}'] | None = None,
3416+
molecule_atom_lens: Int['b n'],
3417+
atom_feats: Float['b m {self.da}'],
34223418
) -> Float['b l n n'] | Float['b l m m']:
34233419

3424-
if self.atom_resolution:
3425-
assert exists(molecule_atom_lens)
3426-
assert exists(atom_feats)
3427-
3428-
pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
3429-
pairwise_repr = pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
3420+
pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
3421+
pairwise_repr = pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
34303422

34313423
logits = self.to_distogram_logits(symmetrize(pairwise_repr))
34323424

@@ -4989,7 +4981,6 @@ def __init__(
49894981
lddt_mask_other_cutoff = 15.,
49904982
augment_kwargs: dict = dict(),
49914983
stochastic_frame_average = False,
4992-
distogram_atom_resolution = False,
49934984
checkpoint_input_embedding = False,
49944985
checkpoint_trunk_pairformer = False,
49954986
checkpoint_distogram_head = False,
@@ -5171,13 +5162,10 @@ def __init__(
51715162

51725163
assert len(distance_bins_tensor) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51735164

5174-
self.distogram_atom_resolution = distogram_atom_resolution
5175-
51765165
self.distogram_head = DistogramHead(
51775166
dim_pairwise = dim_pairwise,
51785167
dim_atom = dim_atom,
51795168
num_dist_bins = num_dist_bins,
5180-
atom_resolution = distogram_atom_resolution,
51815169
)
51825170

51835171
# lddt related
@@ -5679,15 +5667,7 @@ def forward(
56795667
if not exists(distance_labels) and atom_pos_given and exists(distogram_atom_indices):
56805668

56815669
distogram_pos = atom_pos
5682-
5683-
if not self.distogram_atom_resolution:
5684-
# molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
5685-
5686-
distogram_atom_indices = repeat(distogram_atom_indices, 'b n -> b n c', c = distogram_pos.shape[-1])
5687-
molecule_pos = distogram_pos = distogram_pos.gather(1, distogram_atom_indices)
5688-
distogram_mask = valid_distogram_mask
5689-
else:
5690-
distogram_mask = atom_mask
5670+
distogram_mask = atom_mask
56915671

56925672
distogram_dist = torch.cdist(distogram_pos, distogram_pos, p = 2)
56935673
distance_labels = distance_to_bins(distogram_dist, self.distance_bins)
@@ -5700,9 +5680,7 @@ def forward(
57005680
if exists(distance_labels):
57015681

57025682
distogram_mask = pairwise_mask
5703-
5704-
if self.distogram_atom_resolution:
5705-
distogram_mask = to_pairwise_mask(atom_mask)
5683+
distogram_mask = to_pairwise_mask(atom_mask)
57065684

57075685
distance_labels = torch.where(distogram_mask, distance_labels, ignore)
57085686

alphafold3_pytorch/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __getitem__(self, idx):
7777
molecule_atom_indices = molecule_atom_lens - 1
7878
distogram_atom_indices = molecule_atom_lens - 1
7979

80-
distance_labels = torch.randint(0, 37, (seq_len, seq_len))
80+
distance_labels = torch.randint(0, 37, (atom_seq_len, atom_seq_len))
8181
resolved_labels = torch.randint(0, 2, (atom_seq_len,))
8282

8383
majority_asym_id = asym_id.mode().values.item()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.12"
3+
version = "0.3.14"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,27 +511,31 @@ def test_input_embedder():
511511
)
512512

513513
def test_distogram_head():
514+
molecule_atom_lens = torch.ones((2, 16)).long()
515+
atom_feats = torch.randn(2, 16, 128)
514516
pairwise_repr = torch.randn(2, 16, 16, 128)
515517

516518
distogram_head = DistogramHead(dim_pairwise = 128)
517519

518-
logits = distogram_head(pairwise_repr)
520+
logits = distogram_head(
521+
pairwise_repr,
522+
atom_feats = atom_feats,
523+
molecule_atom_lens = molecule_atom_lens
524+
)
519525

520526
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
521527
@pytest.mark.parametrize('stochastic_frame_average', (True, False))
522528
@pytest.mark.parametrize('missing_atoms', (True, False))
523529
@pytest.mark.parametrize('calculate_pae', (True, False))
524530
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
525531
@pytest.mark.parametrize('num_molecule_mods', (0, 4))
526-
@pytest.mark.parametrize('distogram_atom_resolution', (True, False))
527532
def test_alphafold3(
528533
window_atompair_inputs: bool,
529534
stochastic_frame_average: bool,
530535
missing_atoms: bool,
531536
calculate_pae: bool,
532537
atom_transformer_intramolecular_attn: bool,
533538
num_molecule_mods: int,
534-
distogram_atom_resolution: bool
535539
):
536540
seq_len = 16
537541
atom_seq_len = 32
@@ -622,7 +626,6 @@ def test_alphafold3(
622626
)
623627
),
624628
stochastic_frame_average = stochastic_frame_average,
625-
distogram_atom_resolution = distogram_atom_resolution
626629
)
627630

628631
loss, breakdown = alphafold3(

0 commit comments

Comments
 (0)