Skip to content

Commit 846d7af

Browse files
authored
Add MSA dataloading bug fixes (#179)
* Update data_pipeline.py * Update msa_parsing.py * Update inputs.py
1 parent dde0405 commit 846d7af

File tree

3 files changed

+111
-73
lines changed

3 files changed

+111
-73
lines changed

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,17 @@ def make_msa_mask(features: FeatureDict) -> FeatureDict:
5252
@typecheck
5353
def make_msa_features(
5454
msas: Dict[str, msa_parsing.Msa | None],
55-
chain_id_to_chem_types: Dict[str, List[int]],
55+
chain_id_to_residue: Dict[str, Dict[str, List[int]]],
56+
ligand_chemtype_index: int = 3,
5657
raise_missing_exception: bool = False,
5758
) -> FeatureDict:
5859
"""
5960
Construct a feature dictionary of MSA features.
6061
From: https://github.com/aqlaboratory/openfold/blob/6f63267114435f94ac0604b6d89e82ef45d94484/openfold/data/data_pipeline.py#L224
6162
6263
:param msas: The mapping of chain IDs to lists of (optional) MSAs for each chain.
63-
:param chain_id_to_chem_types: The mapping of chain IDs to residue (integer) chemical types.
64+
:param chain_id_to_residue: The mapping of chain IDs to residue information.
65+
:param ligand_index: The index of the ligand in the chemical type list.
6466
:param raise_missing_exception: Whether to raise an exception if no MSAs are provided for any chain.
6567
:return: The MSA feature dictionary.
6668
"""
@@ -70,7 +72,7 @@ def make_msa_features(
7072
# Infer MSA metadata.
7173
max_alignments = 1
7274
for msa in msas.values():
73-
if exists(msa.sequences) and exists(msa.sequences[0]):
75+
if exists(msa) and exists(msa.sequences) and exists(msa.sequences[0]):
7476
max_alignments = max(max_alignments, len(msa.sequences) if msa else 1)
7577

7678
# Collect MSAs.
@@ -84,10 +86,20 @@ def make_msa_features(
8486
species_ids = []
8587
seen_sequences = set()
8688

87-
chain_chem_types = chain_id_to_chem_types[chain_id]
88-
num_res = len(chain_chem_types)
89+
chain_chemtype = chain_id_to_residue[chain_id]["chemtype"]
90+
chain_residue_index = chain_id_to_residue[chain_id]["residue_index"]
8991

90-
msa_residue_constants = get_residue_constants(msa.msa_type.replace("protein", "peptide"))
92+
num_res = len(chain_chemtype)
93+
assert num_res == len(chain_residue_index), (
94+
f"Residue features count mismatch for chain {chain_id}: "
95+
f"{num_res} != {len(chain_residue_index)}"
96+
)
97+
98+
msa_residue_constants = (
99+
get_residue_constants(msa.msa_type.replace("protein", "peptide"))
100+
if exists(msa)
101+
else None
102+
)
91103

92104
gap_ids = [[GAP_ID] * num_res]
93105
deletion_values = [[0] * num_res]
@@ -98,9 +110,11 @@ def make_msa_features(
98110
elif not msa:
99111
# Pad the MSA to the maximum number of alignments
100112
# if the chain does not have any associated alignments.
101-
int_msa_list.append(gap_ids * max_alignments)
102-
deletion_matrix_list.append(deletion_values * max_alignments)
103-
species_ids_list.append(species * max_alignments)
113+
int_msa_list.append(torch.tensor(gap_ids * max_alignments, dtype=torch.long))
114+
deletion_matrix_list.append(
115+
torch.tensor(deletion_values * max_alignments, dtype=torch.float32)
116+
)
117+
species_ids_list.append(np.array(species * max_alignments, dtype=object))
104118
continue
105119

106120
for sequence_index, sequence in enumerate(msa.sequences):
@@ -109,15 +123,26 @@ def make_msa_features(
109123
seen_sequences.add(sequence)
110124

111125
# Convert the MSA to integers while handling
112-
# ligands and (unmappable) modified polymer residues.
126+
# ligands and modified polymer residues.
113127
msa_res_types = []
114128
msa_deletion_values = []
115129

116-
polymer_res_index = 0
130+
polymer_residue_index = -1
117131

118-
for chem_type in chain_chem_types:
119-
is_ligand = chem_type == 3
120-
chem_residue_constants = get_residue_constants(res_chem_index=chem_type)
132+
for idx, (chemtype, residue_index) in enumerate(
133+
zip(chain_chemtype, chain_residue_index)
134+
):
135+
is_polymer = chemtype < ligand_chemtype_index
136+
is_ligand = not is_polymer
137+
138+
chem_residue_constants = get_residue_constants(res_chem_index=chemtype)
139+
140+
# NOTE: For modified polymer residues, we only increment the polymer residue index
141+
# when the current (atomized) modified polymer residue's atom sequence ends.
142+
increment_index = (
143+
0 < idx < num_res and chain_residue_index[idx - 1] != residue_index
144+
)
145+
polymer_residue_index += 1 if is_polymer and (idx == 0 or increment_index) else 0
121146

122147
if is_ligand:
123148
# NOTE: For ligands, we use the unknown amino acid type.
@@ -131,18 +156,20 @@ def make_msa_features(
131156
if chem_residue_constants != msa_residue_constants:
132157
msa_res_type = chem_residue_constants.restype_num
133158
else:
134-
res = sequence[polymer_res_index]
159+
res = sequence[polymer_residue_index]
135160
msa_res_type = msa_residue_constants.MSA_CHAR_TO_ID.get(
136161
res, msa_residue_constants.restype_num
137162
)
138163

139-
msa_deletion_value = msa.deletion_matrix[sequence_index][polymer_res_index]
140-
141-
polymer_res_index += 1
164+
msa_deletion_value = msa.deletion_matrix[sequence_index][polymer_residue_index]
142165

143166
msa_res_types.append(msa_res_type)
144167
msa_deletion_values.append(msa_deletion_value)
145168

169+
assert polymer_residue_index + 1 == len(
170+
sequence
171+
), f"Polymer residue index length mismatch for MSA chain {chain_id}: {polymer_residue_index + 1} != {len(sequence)}"
172+
146173
int_msa.append(msa_res_types)
147174
deletion_matrix.append(msa_deletion_values)
148175

alphafold3_pytorch/data/msa_parsing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __len__(self):
121121

122122
def truncate(self, max_seqs: int):
123123
"""Truncates the MSA to the first `max_seqs` sequences."""
124+
max_seqs = min(len(self.sequences), max_seqs)
124125
return Msa(
125126
sequences=self.sequences[:max_seqs],
126127
deletion_matrix=self.deletion_matrix[:max_seqs],
@@ -130,6 +131,7 @@ def truncate(self, max_seqs: int):
130131

131132
def random_truncate(self, max_seqs: int):
132133
"""Truncates the MSA to a random range of `max_seqs` sequences."""
134+
max_seqs = min(len(self.sequences), max_seqs)
133135
start = random.randint(0, len(self.sequences) - max_seqs) # nosec
134136
return Msa(
135137
sequences=self.sequences[start : start + max_seqs],

alphafold3_pytorch/inputs.py

Lines changed: 64 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2264,7 +2264,7 @@ def find_mismatched_symmetry(
22642264
def load_msa_from_msa_dir(
22652265
msa_dir: str | None,
22662266
file_id: str,
2267-
chain_id_to_chem_types: Dict[str, List[int]],
2267+
chain_id_to_residue: Dict[str, Dict[str, List[int]]],
22682268
max_msas_per_chain: int | None = None,
22692269
randomly_truncate: bool = True,
22702270
raise_missing_exception: bool = False,
@@ -2279,16 +2279,16 @@ def load_msa_from_msa_dir(
22792279
return {}
22802280

22812281
msas = {}
2282-
for chain_id in chain_id_to_chem_types:
2282+
for chain_id in chain_id_to_residue:
22832283
msa_fpaths = glob.glob(os.path.join(msa_dir, f"{file_id}{chain_id}_*.a3m"))
22842284

22852285
if not msa_fpaths:
22862286
msas[chain_id] = None
22872287
continue
22882288

22892289
# NOTE: A single chain-specific MSA file contains alignments for all polymer residues in the chain,
2290-
# but the ligand (and some "unmappable" modified polymer residues) are not included in the MSA file
2291-
# and therefore must be manually inserted into the MSAs as unknown amino acid residues.
2290+
# but the chain's ligands are not included in the MSA file and therefore must be manually inserted
2291+
# into the MSAs as unknown amino acid residues.
22922292
assert len(msa_fpaths) == 1, (
22932293
f"{len(msa_fpaths)} MSA files found for chain {chain_id} of file {file_id}. "
22942294
"Please ensure that one MSA file is present for each chain."
@@ -2310,7 +2310,7 @@ def load_msa_from_msa_dir(
23102310
)
23112311
msas[chain_id] = msa
23122312

2313-
features = make_msa_features(msas, chain_id_to_chem_types)
2313+
features = make_msa_features(msas, chain_id_to_residue)
23142314
features = make_msa_mask(features)
23152315

23162316
return features
@@ -2606,10 +2606,12 @@ def pdb_input_to_molecule_input(
26062606

26072607
# concat for all of additional_molecule_feats
26082608

2609+
# NOTE: `Biomolecule.residue_index` is 1-based originally
2610+
residue_index = torch.from_numpy(biomol.residue_index) - 1
2611+
26092612
additional_molecule_feats = torch.stack(
26102613
(
2611-
# NOTE: `Biomolecule.residue_index` is 1-based originally
2612-
torch.from_numpy(biomol.residue_index) - 1,
2614+
residue_index,
26132615
torch.arange(num_tokens),
26142616
torch.from_numpy(biomol.chain_index),
26152617
entity_ids,
@@ -2790,15 +2792,63 @@ def pdb_input_to_molecule_input(
27902792
num_present_atoms = mol_total_atoms - num_missing_atom_indices
27912793
assert num_present_atoms == int(biomol.atom_mask.sum())
27922794

2795+
# handle `atom_indices_for_frame` for the PAE
2796+
2797+
atom_indices_for_frame = tensor(
2798+
[default(indices, (-1, -1, -1)) for indices in atom_indices_for_frame]
2799+
)
2800+
2801+
# build offsets for all indices
2802+
2803+
# derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
2804+
atoms_per_molecule = tensor([mol.GetNumAtoms() for mol in molecules])
2805+
ones = torch.ones_like(atoms_per_molecule)
2806+
2807+
# `is_molecule_mod` can either be
2808+
# 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
2809+
# 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
2810+
is_molecule_mod = tensor(is_molecule_mod)
2811+
is_molecule_any_mod = False
2812+
2813+
if is_molecule_mod.ndim == 2:
2814+
is_molecule_any_mod = is_molecule_mod[unique_chain_residue_indices].any(dim=-1)
2815+
else:
2816+
is_molecule_any_mod = is_molecule_mod[unique_chain_residue_indices]
2817+
2818+
# get `one_token_per_atom`
2819+
# default to what the paper did, which is ligands and any modified biomolecule
2820+
is_ligand = is_molecule_types[unique_chain_residue_indices][..., IS_LIGAND_INDEX]
2821+
one_token_per_atom = is_ligand | is_molecule_any_mod
2822+
2823+
assert len(molecules) == len(one_token_per_atom)
2824+
2825+
# derive the number of repeats needed to expand molecule lengths to token lengths
2826+
token_repeats = torch.where(one_token_per_atom, atoms_per_molecule, ones)
2827+
2828+
# craft offsets for all atom indices
2829+
atom_indices_offsets = repeat_interleave(
2830+
exclusive_cumsum(atoms_per_molecule), token_repeats, dim=0
2831+
)
2832+
2833+
# offset only positive atom indices
2834+
distogram_atom_indices = offset_only_positive(distogram_atom_indices, atom_indices_offsets)
2835+
molecule_atom_indices = offset_only_positive(molecule_atom_indices, atom_indices_offsets)
2836+
atom_indices_for_frame = offset_only_positive(
2837+
atom_indices_for_frame, atom_indices_offsets[..., None]
2838+
)
2839+
27932840
# retrieve multiple sequence alignments (MSAs) for each chain
27942841
# NOTE: if they are not locally available, `Nones` will be used
27952842
msa_chain_ids = list(dict.fromkeys(biomol.chain_id.tolist()))
2796-
chain_id_to_chem_types = {
2797-
chain_id: biomol.chemtype[biomol.chain_id == chain_id].tolist()
2843+
chain_id_to_residue = {
2844+
chain_id: {
2845+
"chemtype": biomol.chemtype[biomol.chain_id == chain_id].tolist(),
2846+
"residue_index": residue_index[biomol.chain_id == chain_id].tolist(),
2847+
}
27982848
for chain_id in msa_chain_ids
27992849
}
28002850
msa_features = load_msa_from_msa_dir(
2801-
i.msa_dir, file_id, chain_id_to_chem_types, max_msas_per_chain=i.max_msas_per_chain
2851+
i.msa_dir, file_id, chain_id_to_residue, max_msas_per_chain=i.max_msas_per_chain
28022852
)
28032853

28042854
msa = msa_features.get("msa")
@@ -2817,6 +2867,10 @@ def pdb_input_to_molecule_input(
28172867
num_msas = len(msa) if exists(msa) else 1
28182868

28192869
if exists(msa):
2870+
assert (
2871+
msa.shape[-1] == num_tokens
2872+
), f"The number of tokens in the MSA ({msa.shape[-1]}) does not match the number of tokens in the biomolecule ({num_tokens}). "
2873+
28202874
has_deletion = torch.clip(msa_features["deletion_matrix"], 0.0, 1.0)
28212875
deletion_value = torch.atan(msa_features["deletion_matrix"] / 3.0) * (2.0 / torch.pi)
28222876

@@ -2883,51 +2937,6 @@ def pdb_input_to_molecule_input(
28832937
is_resolved_label = ((resolution >= 0.1) & (resolution <= 3.0)).item()
28842938
resolved_labels = torch.full((num_atoms,), is_resolved_label, dtype=torch.long)
28852939

2886-
# handle `atom_indices_for_frame` for the PAE
2887-
2888-
atom_indices_for_frame = tensor(
2889-
[default(indices, (-1, -1, -1)) for indices in atom_indices_for_frame]
2890-
)
2891-
2892-
# build offsets for all indices
2893-
2894-
# derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
2895-
atoms_per_molecule = tensor([mol.GetNumAtoms() for mol in molecules])
2896-
ones = torch.ones_like(atoms_per_molecule)
2897-
2898-
# `is_molecule_mod` can either be
2899-
# 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
2900-
# 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
2901-
is_molecule_mod = tensor(is_molecule_mod)
2902-
is_molecule_any_mod = False
2903-
2904-
if is_molecule_mod.ndim == 2:
2905-
is_molecule_any_mod = is_molecule_mod[unique_chain_residue_indices].any(dim=-1)
2906-
else:
2907-
is_molecule_any_mod = is_molecule_mod[unique_chain_residue_indices]
2908-
2909-
# get `one_token_per_atom`
2910-
# default to what the paper did, which is ligands and any modified biomolecule
2911-
is_ligand = is_molecule_types[unique_chain_residue_indices][..., IS_LIGAND_INDEX]
2912-
one_token_per_atom = is_ligand | is_molecule_any_mod
2913-
2914-
assert len(molecules) == len(one_token_per_atom)
2915-
2916-
# derive the number of repeats needed to expand molecule lengths to token lengths
2917-
token_repeats = torch.where(one_token_per_atom, atoms_per_molecule, ones)
2918-
2919-
# craft offsets for all atom indices
2920-
atom_indices_offsets = repeat_interleave(
2921-
exclusive_cumsum(atoms_per_molecule), token_repeats, dim=0
2922-
)
2923-
2924-
# offset only positive atom indices
2925-
distogram_atom_indices = offset_only_positive(distogram_atom_indices, atom_indices_offsets)
2926-
molecule_atom_indices = offset_only_positive(molecule_atom_indices, atom_indices_offsets)
2927-
atom_indices_for_frame = offset_only_positive(
2928-
atom_indices_for_frame, atom_indices_offsets[..., None]
2929-
)
2930-
29312940
# create molecule input
29322941

29332942
molecule_input = MoleculeInput(

0 commit comments

Comments
 (0)