Skip to content

Commit 37f8108

Browse files
committed
add is_molecule_mod to AtomInput and BatchAtomInput
1 parent 241318c commit 37f8108

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
IS_LIGAND,
5555
IS_METAL_ION,
5656
NUM_MOLECULE_IDS,
57+
DEFAULT_NUM_MOLECULE_MODS,
5758
ADDITIONAL_MOLECULE_FEATS
5859
)
5960

@@ -4519,6 +4520,7 @@ def __init__(
45194520
self.w = atoms_per_window
45204521
self.dapi = self.dim_atompair_inputs
45214522
self.dai = self.dim_atom_inputs
4523+
self.num_mods = num_molecule_mods
45224524

45234525
@property
45244526
def device(self):
@@ -4605,7 +4607,7 @@ def forward(
46054607
additional_token_feats: Float['b n {self.dim_additional_token_feats}'] | None = None,
46064608
atom_ids: Int['b m'] | None = None,
46074609
atompair_ids: Int['b m m'] | Int['b nw {self.w} {self.w*2}'] | None = None,
4608-
is_molecule_mod: Bool['b n num_mods'] | None = None,
4610+
is_molecule_mod: Bool['b n {self.num_mods}'] | None = None,
46094611
atom_mask: Bool['b m'] | None = None,
46104612
missing_atom_mask: Bool['b m'] | None = None,
46114613
atom_parent_ids: Int['b m'] | None = None,

alphafold3_pytorch/inputs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
8080
NUM_MOLECULE_IDS = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 2
8181

82+
DEFAULT_NUM_MOLECULE_MODS = 5
8283
ADDITIONAL_MOLECULE_FEATS = 5
8384

8485
CCD_COMPONENTS_FILEPATH = os.path.join('data', 'ccd_data', 'components.cif')
@@ -140,7 +141,8 @@ class AtomInput:
140141
atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
141142
additional_molecule_feats: Int[f'n {ADDITIONAL_MOLECULE_FEATS}']
142143
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
143-
additional_token_feats: Float[f'n dtf'] | None = None
144+
is_molecule_mod: Bool['n num_mods'] | None = None
145+
additional_token_feats: Float['n dtf'] | None = None
144146
templates: Float['t n n dt'] | None = None
145147
msa: Float['s n dm'] | None = None
146148
token_bonds: Bool['n n'] | None = None
@@ -171,7 +173,8 @@ class BatchedAtomInput:
171173
atompair_inputs: Float['b m m dapi'] | Float['b nw w (w*2) dapi']
172174
additional_molecule_feats: Int[f'b n {ADDITIONAL_MOLECULE_FEATS}']
173175
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}']
174-
additional_token_feats: Float[f'b n dtf'] | None = None
176+
is_molecule_mod: Bool['b n num_mods'] | None = None
177+
additional_token_feats: Float['b n dtf'] | None = None
175178
templates: Float['b t n n dt'] | None = None
176179
msa: Float['b s n dm'] | None = None
177180
token_bonds: Bool['b n n'] | None = None
@@ -326,6 +329,7 @@ class MoleculeInput:
326329
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
327330
src_tgt_atom_indices: Int['n 2']
328331
token_bonds: Bool['n n']
332+
is_molecule_mod: Bool['n num_mods'] | None = None
329333
molecule_atom_indices: List[int | None] | None = None
330334
distogram_atom_indices: List[int | None] | None = None
331335
missing_atom_indices: List[Int[' _'] | None] | None = None

alphafold3_pytorch/mocks.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import torch
55
from torch.utils.data import Dataset
66
from alphafold3_pytorch import AtomInput
7-
from alphafold3_pytorch.inputs import IS_MOLECULE_TYPES
7+
8+
from alphafold3_pytorch.inputs import (
9+
IS_MOLECULE_TYPES,
10+
DEFAULT_NUM_MOLECULE_MODS
11+
)
812

913
# mock dataset
1014

@@ -13,11 +17,13 @@ def __init__(
1317
self,
1418
data_length,
1519
max_seq_len = 16,
16-
atoms_per_window = 4
20+
atoms_per_window = 4,
21+
has_molecule_mods = False
1722
):
1823
self.data_length = data_length
1924
self.max_seq_len = max_seq_len
2025
self.atoms_per_window = atoms_per_window
26+
self.has_molecule_mods = has_molecule_mods
2127

2228
def __len__(self):
2329
return self.data_length
@@ -33,6 +39,11 @@ def __getitem__(self, idx):
3339
additional_molecule_feats = torch.randint(0, 2, (seq_len, 5))
3440
additional_token_feats = torch.randn(seq_len, 2)
3541
is_molecule_types = torch.randint(0, 2, (seq_len, IS_MOLECULE_TYPES)).bool()
42+
43+
is_molecule_mod = None
44+
if self.has_molecule_mods:
45+
is_molecule_mod = torch.rand((seq_len, DEFAULT_NUM_MOLECULE_MODS)) < 0.05
46+
3647
molecule_ids = torch.randint(0, 32, (seq_len,))
3748
token_bonds = torch.randint(0, 2, (seq_len, seq_len)).bool()
3849

@@ -65,6 +76,7 @@ def __getitem__(self, idx):
6576
additional_molecule_feats = additional_molecule_feats,
6677
additional_token_feats = additional_token_feats,
6778
is_molecule_types = is_molecule_types,
79+
is_molecule_mod = is_molecule_mod,
6880
templates = templates,
6981
template_mask = template_mask,
7082
msa = msa,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.74"
3+
version = "0.2.75"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)