Skip to content

Commit 336ba7a

Browse files
committed
feat: add apo to holo mask
1 parent 56b2d13 commit 336ba7a

File tree

4 files changed

+78
-36
lines changed

4 files changed

+78
-36
lines changed

docs/examples/5_dataset_and_loader.ipynb

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@
8888
"val_data = val_dataset[1]"
8989
]
9090
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"val_data['features_and_coords']['apo_features']"
98+
]
99+
},
91100
{
92101
"cell_type": "code",
93102
"execution_count": null,
@@ -125,8 +134,24 @@
125134
"outputs": [],
126135
"source": [
127136
"for k, v in test_torch[\"features_and_coords\"].items():\n",
128-
" print(k, v.shape)"
137+
" print(k, v)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"test_torch[\"features_and_coords\"].keys()"
129147
]
148+
},
149+
{
150+
"cell_type": "code",
151+
"execution_count": null,
152+
"metadata": {},
153+
"outputs": [],
154+
"source": []
130155
}
131156
],
132157
"metadata": {

src/plinder/core/loader/featurizer.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import random
12
from typing import Any
23

34
import torch
@@ -13,8 +14,13 @@
1314

1415

1516
def system_featurizer(
16-
system: PlinderSystem, pad_value: int = -100, featurize_apo: bool = True
17+
system: PlinderSystem,
18+
pad_value: int = -100,
19+
featurize_apo: bool = True,
20+
seed: int = 42,
1721
) -> dict[str, Any]:
22+
# Set seed
23+
random.seed(seed)
1824
# Load holo and alternate (apo and pred) structures
1925
holo_structure = system.holo_structure
2026
apo_structures = system.alternate_structures
@@ -64,42 +70,39 @@ def system_featurizer(
6470
"holo_protein_coordinates_stacked": holo_protein_coordinates_stacked,
6571
"holo_protein_calpha_coordinates_stacked": holo_protein_calpha_coordinates_stacked,
6672
}
67-
68-
# Apo features
69-
# Note: Since there are multiple apo/pred structures linked to a holo structure, we will
70-
# stack all of them and let users choose how the want to use it
7173
if featurize_apo:
72-
all_apo_protein_atom_arrays = []
73-
all_apo_sequence_atom_masks_stacked = []
74-
all_apo_input_sequence_residue_masks_stacked = []
75-
all_apo_protein_coordinates_arrays_stacked = []
76-
all_apo_protein_calpha_coordinates_arrays_stacked = []
77-
for apo_id, apo_structure in apo_structures.items():
78-
apo_structure.set_chain(holo_chain)
79-
all_apo_protein_atom_arrays.append(apo_structure.protein_atom_array)
80-
all_apo_sequence_atom_masks_stacked.append(apo_structure.sequence_atom_mask)
81-
all_apo_input_sequence_residue_masks_stacked.append(
82-
apo_structure.input_sequence_residue_mask_stacked
83-
)
84-
all_apo_protein_coordinates_arrays_stacked.append(
85-
apo_structure.protein_coords
86-
)
87-
all_apo_protein_calpha_coordinates_arrays_stacked.append(
88-
apo_structure.protein_calpha_coords
89-
)
90-
all_apo_features = {
91-
"all_apo_sequence_atom_masks_stacked": all_apo_sequence_atom_masks_stacked,
92-
"all_apo_input_sequence_residue_masks_stacked": all_apo_input_sequence_residue_masks_stacked,
93-
"all_apo_protein_coordinates_arrays_stacked": all_apo_protein_coordinates_arrays_stacked,
94-
"all_apo_protein_calpha_coordinates_arrays_stacked": all_apo_protein_calpha_coordinates_arrays_stacked,
74+
selected_apo = apo_structures[random.choice(list(apo_structures.keys()))]
75+
# Set apo chain to match holo
76+
selected_apo.set_chain(holo_chain)
77+
apo_sequence_atom_mask_stacked = selected_apo.sequence_atom_mask
78+
apo_input_sequence_residue_mask_stacked = (
79+
selected_apo.input_sequence_residue_mask_stacked
80+
)
81+
apo_protein_coordinates_array_stacked = selected_apo.protein_coords
82+
apo_protein_calpha_coordinates_array_stacked = (
83+
selected_apo.protein_calpha_coords
84+
)
85+
# Apo to holo cropping mask
86+
apo_sequence_to_holo_structure_mask_stacked = (
87+
holo_structure.protein_structure_residue_mask(selected_apo)
88+
)
89+
90+
apo_features = {
91+
"apo_sequence_atom_masks_stacked": apo_sequence_atom_mask_stacked,
92+
"apo_input_sequence_residue_masks_stacked": apo_input_sequence_residue_mask_stacked,
93+
"apo_protein_coordinates_arrays_stacked": apo_protein_coordinates_array_stacked,
94+
"apo_protein_calpha_coordinates_arrays_stacked": apo_protein_calpha_coordinates_array_stacked,
95+
"apo_sequence_to_holo_structure_mask_stacked": apo_sequence_to_holo_structure_mask_stacked,
9596
}
9697
else:
97-
all_apo_features = {
98-
"all_apo_sequence_atom_masks_stacked": [],
99-
"all_apo_input_sequence_residue_masks_stacked": [],
100-
"all_apo_protein_coordinates_arrays_stacked": [],
101-
"all_apo_protein_calpha_coordinates_arrays_stacked": [],
98+
apo_features = {
99+
"apo_sequence_atom_masks_stacked": [],
100+
"apo_input_sequence_residue_masks_stacked": [],
101+
"apo_protein_coordinates_arrays_stacked": [],
102+
"apo_protein_calpha_coordinates_arrays_stacked": [],
103+
"apo_sequence_to_holo_structure_mask_stacked": [],
102104
}
105+
103106
# Ligand features
104107
input_ligand_templates = (
105108
holo_structure.input_ligand_templates
@@ -140,7 +143,7 @@ def system_featurizer(
140143
features = {
141144
"sequence_features": sequence_features,
142145
"holo_features": holo_features,
143-
"apo_features": all_apo_features,
146+
"apo_features": apo_features,
144147
"ligand_features": ligand_features,
145148
}
146149

src/plinder/core/loader/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def collate_complex(
127127
collated_properties, dim=0, value=pad_value
128128
)
129129
all_collated_and_padded_properties[feat_group] = collated_and_padded_properties
130-
return collated_and_padded_properties
130+
return all_collated_and_padded_properties
131131

132132

133133
def collate_batch(

src/plinder/core/structure/structure.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,20 @@ def input_sequence_residue_mask_stacked(self) -> list[list[int]]:
454454
)
455455
return [seqres_masks[ch] for ch in self.protein_chain_ordered]
456456

457+
@property
458+
def protein_structure_residue_mask(self, other: Structure) -> list[list[int]]:
459+
"""Mask residues from a given structure to another structure"""
460+
self_protein_atom_array = self.protein_atom_array
461+
other_protein_sequence_from_structure = other.protein_sequence_from_structure
462+
other_chain = other_protein_sequence_from_structure.protein_chains[0]
463+
assert self_protein_atom_array is not None
464+
assert other_protein_sequence_from_structure is not None
465+
other_sequence_dict = {other_chain: other_protein_sequence_from_structure}
466+
seqres_masks = get_residue_index_mapping_mask(
467+
other_sequence_dict, self_protein_atom_array
468+
)
469+
return [seqres_masks[ch] for ch in self.protein_chain_ordered]
470+
457471
@property
458472
def input_sequence_list_ordered_by_chain(self) -> list[str] | None:
459473
"""List of protein chains ordered the way it is in structure."""

0 commit comments

Comments
 (0)