Skip to content

Commit 08d4c55

Browse files
authored
Add test_pdbinput_input unit test (#88)
* Update amino_acid_constants.py * Update biomolecule.py * Update dna_constants.py * Update ligand_constants.py * Update rna_constants.py * Update life.py * Update alphafold3.py * Update inputs.py * Update test_input.py * Create model_utils.py * Update model_utils.py
1 parent 2da0c5a commit 08d4c55

File tree

10 files changed

+1564
-342
lines changed

10 files changed

+1564
-342
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
n - molecule sequence length
6969
i - molecule sequence length (source)
7070
j - molecule sequence length (target)
71+
l - present (i.e., non-missing) atom sequence length
7172
m - atom sequence length
7273
nw - windowed sequence length
7374
d - feature dimension
@@ -3377,9 +3378,10 @@ def forward(
33773378
resolved_labels: Int['b n'] | None = None,
33783379
return_loss_breakdown = False,
33793380
return_loss: bool = None,
3381+
return_present_sampled_atoms: bool = False,
33803382
num_rollout_steps: int = 20,
33813383
rollout_show_tqdm_pbar: bool = False
3382-
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
3384+
) -> Float['b m 3'] | Float['l 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
33833385

33843386
atom_seq_len = atom_inputs.shape[-2]
33853387

@@ -3622,6 +3624,8 @@ def forward(
36223624

36233625
if exists(atom_mask):
36243626
sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
3627+
if exists(missing_atom_mask) and return_present_sampled_atoms:
3628+
sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
36253629

36263630
return sampled_atom_pos
36273631

alphafold3_pytorch/common/amino_acid_constants.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Amino acid constants used in AlphaFold."""
22

3-
import numpy as np
4-
53
from typing import Final
64

5+
import numpy as np
6+
77
# This mapping is used when we need to store atom data in a format that requires
88
# fixed atom data size for every residue (e.g. a numpy array).
99
# From: https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/common/residue_constants.py#L492C1-L497C2
@@ -59,6 +59,7 @@
5959
atom_types_set = set(atom_types)
6060
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6161
atom_type_num = len(atom_types) # := 37 + 10 null types := 47.
62+
res_rep_atom_index = 1 # The index of the atom used to represent the center of the residue.
6263

6364

6465
# This is the standard residue order when coding AA type as a number.
@@ -111,6 +112,7 @@
111112
"W": "TRP",
112113
"Y": "TYR",
113114
"V": "VAL",
115+
"X": "UNK",
114116
}
115117

116118
BIOMOLECULE_CHAIN: Final[str] = "polypeptide(L)"

alphafold3_pytorch/common/biomolecule.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"_struct_conn.",
4040
]
4141
MMCIF_PREFIXES_TO_DROP_POST_AF3 = MMCIF_PREFIXES_TO_DROP_POST_PARSING + [
42+
"_audit_author.",
4243
"_citation.",
4344
"_citation_author.",
4445
]
@@ -53,6 +54,10 @@ class Biomolecule:
5354
# amino acid residues.
5455
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
5556

57+
# Name of each residue-representative atom as a string,
58+
# with one entry per (pseudo)residue (a.k.a. token).
59+
atom_name: np.ndarray # [num_res]
60+
5661
# Amino-acid or nucleotide type for each residue represented as an integer
5762
# between 0 and 31, where:
5863
# 20 represents the unknown amino acid 'X';
@@ -124,6 +129,7 @@ def __add__(self, other: "Biomolecule") -> "Biomolecule":
124129
"""Merges two `Biomolecule` instances."""
125130
return Biomolecule(
126131
atom_positions=np.concatenate([self.atom_positions, other.atom_positions], axis=0),
132+
atom_name=np.concatenate([self.atom_name, other.atom_name], axis=0),
127133
restype=np.concatenate([self.restype, other.restype], axis=0),
128134
atom_mask=np.concatenate([self.atom_mask, other.atom_mask], axis=0),
129135
residue_index=np.concatenate([self.residue_index, other.residue_index], axis=0),
@@ -169,6 +175,7 @@ def subset_chains(self, subset_chain_ids: List[str]) -> "Biomolecule":
169175
chain_mask = np.isin(self.chain_index, list(subset_chain_index_mapping.keys()))
170176
return Biomolecule(
171177
atom_positions=self.atom_positions[chain_mask],
178+
atom_name=self.atom_name[chain_mask],
172179
restype=self.restype[chain_mask],
173180
atom_mask=self.atom_mask[chain_mask],
174181
residue_index=self.residue_index[chain_mask],
@@ -203,6 +210,7 @@ def repeat(self, coord: np.ndarray) -> "Biomolecule":
203210
"""Repeat a Biomolecule according to a (repeated) coordinate array."""
204211
return Biomolecule(
205212
atom_positions=coord.reshape(-1, 47, 3),
213+
atom_name=np.tile(self.atom_name, (coord.shape[0], 1)).reshape(-1),
206214
restype=np.tile(self.restype, (coord.shape[0], 1)).reshape(-1),
207215
atom_mask=np.tile(self.atom_mask, (coord.shape[0], 1, 1)).reshape(-1, 47),
208216
residue_index=np.tile(self.residue_index, (coord.shape[0], 1)).reshape(-1),
@@ -258,13 +266,18 @@ def get_ligand_atom_name(atom_name: str, atom_types_set: Set[str]) -> str:
258266
elif len(atom_name) == 2:
259267
return atom_name if atom_name in atom_types_set else atom_name[0]
260268
elif len(atom_name) == 3:
261-
return (
262-
atom_name
263-
if atom_name in atom_types_set
264-
else (
265-
atom_name[:2] if atom_name[:2] in atom_types_set else atom_name[0] + atom_name[2]
266-
)
267-
)
269+
if atom_name in atom_types_set:
270+
return atom_name
271+
elif atom_name[:2] in atom_types_set:
272+
return atom_name[:2]
273+
elif atom_name[1:] in atom_types_set:
274+
return atom_name[1:]
275+
elif atom_name[0] + atom_name[2] in atom_types_set:
276+
return atom_name[0] + atom_name[2]
277+
elif atom_name.split("H")[0] in atom_types_set:
278+
return atom_name.split("H")[0]
279+
else:
280+
return atom_name
268281
else:
269282
return atom_name
270283

@@ -334,6 +347,7 @@ def _from_mmcif_object(
334347
model = models[0]
335348

336349
atom_positions = []
350+
atom_names = []
337351
restype = []
338352
chemid = []
339353
chemtype = []
@@ -412,6 +426,9 @@ def _from_mmcif_object(
412426
chemid.append(res_chem_comp_details.id)
413427
chemtype.append(residue_constants.chemtype_num)
414428
atom_positions.append(pos)
429+
atom_names.append(
430+
residue_constants.atom_types[residue_constants.res_rep_atom_index]
431+
)
415432
atom_mask.append(mask)
416433
residue_index.append(res_index + 1)
417434
chain_ids.append(chain.id)
@@ -448,6 +465,7 @@ def _from_mmcif_object(
448465
atom_name = get_ligand_atom_name(atom.name, residue_constants.atom_types_set)
449466
if atom_name not in residue_constants.atom_types_set:
450467
atom_name = "ATM"
468+
atom_names.append(atom_name)
451469
pos[residue_constants.atom_order[atom_name]] = atom.coord
452470
mask[residue_constants.atom_order[atom_name]] = 1.0
453471
res_b_factors[residue_constants.atom_order[atom_name]] = atom.bfactor
@@ -505,6 +523,7 @@ def _from_mmcif_object(
505523

506524
return Biomolecule(
507525
atom_positions=np.array(atom_positions),
526+
atom_name=np.array(atom_names),
508527
restype=np.array(restype),
509528
atom_mask=np.array(atom_mask),
510529
residue_index=np.array(residue_index),

alphafold3_pytorch/common/dna_constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Deoxyribonucleic acid (DNA) constants used in AlphaFold."""
22

3-
import numpy as np
4-
53
from typing import Final
64

5+
import numpy as np
6+
77
from alphafold3_pytorch.common import amino_acid_constants, rna_constants
88

99
# This mapping is used when we need to store atom data in a format that requires
@@ -62,6 +62,7 @@
6262
atom_types_set = set(atom_types)
6363
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6464
atom_type_num = len(atom_types) # := 28 + 19 null types := 47.
65+
res_rep_atom_index = 11 # The index of the atom used to represent the center of the residue.
6566

6667

6768
# This is the standard residue order when coding DNA type as a number.
@@ -79,6 +80,7 @@
7980
"C": "DC",
8081
"G": "DG",
8182
"T": "DT",
83+
"X": "DN",
8284
}
8385

8486
BIOMOLECULE_CHAIN: Final[str] = "polydeoxyribonucleotide"
@@ -252,4 +254,5 @@ def _make_constants():
252254
compact_atom_idx = restype_name_to_compact_atom_names[resname].index(atomname)
253255
restype_atom47_to_compact_atom[restype, atomtype] = compact_atom_idx
254256

257+
255258
_make_constants()

alphafold3_pytorch/common/ligand_constants.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Ligand constants used in AlphaFold."""
22

3-
import numpy as np
4-
53
from typing import Final
64

5+
import numpy as np
6+
77
from alphafold3_pytorch.common import amino_acid_constants, dna_constants
88

99
# This mapping is used when we need to store atom data in a format that requires
@@ -58,9 +58,62 @@
5858
"ZN",
5959
"ATM",
6060
]
61+
element_types = [
62+
# NOTE: Taken from: https://github.com/baker-laboratory/RoseTTAFold-All-Atom/blob/c1fd92455be2a4133ad147242fc91cea35477282/rf2aa/chemical.py#L117C13-L126C18
63+
"Al",
64+
"As",
65+
"Au",
66+
"B",
67+
"Be",
68+
"Br",
69+
"C",
70+
"Ca",
71+
"Cl",
72+
"Co",
73+
"Cr",
74+
"Cu",
75+
"F",
76+
"Fe",
77+
"Hg",
78+
"I",
79+
"Ir",
80+
"K",
81+
"Li",
82+
"Mg",
83+
"Mn",
84+
"Mo",
85+
"N",
86+
"Ni",
87+
"O",
88+
"Os",
89+
"P",
90+
"Pb",
91+
"Pd",
92+
"Pr",
93+
"Pt",
94+
"Re",
95+
"Rh",
96+
"Ru",
97+
"S",
98+
"Sb",
99+
"Se",
100+
"Si",
101+
"Sn",
102+
"Tb",
103+
"Te",
104+
"U",
105+
"W",
106+
"V",
107+
"Y",
108+
"Zn",
109+
"ATM",
110+
]
61111
atom_types_set = set(atom_types)
62112
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
63113
atom_type_num = len(atom_types) # := 47.
114+
res_rep_atom_index = (
115+
len(atom_types) - 1
116+
) # := 46 # The index of the atom used to represent the center of a ligand pseudoresidue.
64117

65118

66119
# All ligand residues are mapped to the unknown amino acid type index (:= 20).
@@ -109,4 +162,5 @@ def _make_constants():
109162
compact_atom_idx = restype_name_to_compact_atom_names[resname].index(atomname)
110163
restype_atom47_to_compact_atom[restype, atomtype] = compact_atom_idx
111164

165+
112166
_make_constants()

alphafold3_pytorch/common/rna_constants.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Ribonucleic acid (RNA) constants used in AlphaFold."""
22

3-
import numpy as np
4-
53
from typing import Final
64

5+
import numpy as np
6+
77
from alphafold3_pytorch.common import amino_acid_constants
88

99
# This mapping is used when we need to store atom data in a format that requires
@@ -62,6 +62,7 @@
6262
atom_types_set = set(atom_types)
6363
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6464
atom_type_num = len(atom_types) # := 28 + 19 null types := 47.
65+
res_rep_atom_index = 12 # The index of the atom used to represent the center of the residue.
6566

6667

6768
# This is the standard residue order when coding RNA type as a number.
@@ -72,7 +73,7 @@
7273
restype_num = min_restype_num + len(restypes) # := 21 + 4 := 25.
7374

7475

75-
restype_1to3 = {"A": "A", "C": "C", "G": "G", "U": "U"}
76+
restype_1to3 = {"A": "A", "C": "C", "G": "G", "U": "U", "X": "N"}
7677

7778
BIOMOLECULE_CHAIN: Final[str] = "polyribonucleotide"
7879
POLYMER_CHAIN: Final[str] = "polymer"
@@ -244,4 +245,5 @@ def _make_constants():
244245
compact_atom_idx = restype_name_to_compact_atom_names[resname].index(atomname)
245246
restype_atom47_to_compact_atom[restype, atomtype] = compact_atom_idx
246247

248+
247249
_make_constants()

0 commit comments

Comments
 (0)