|
| 1 | +import random |
1 | 2 | from typing import Any |
2 | 3 |
|
3 | 4 | import torch |
|
13 | 14 |
|
14 | 15 |
|
15 | 16 | def system_featurizer( |
16 | | - system: PlinderSystem, pad_value: int = -100, featurize_apo: bool = True |
| 17 | + system: PlinderSystem, |
| 18 | + pad_value: int = -100, |
| 19 | + featurize_apo: bool = True, |
| 20 | + seed: int = 42, |
17 | 21 | ) -> dict[str, Any]: |
| 22 | + # Seed Python's global RNG so the random apo structure selection below is reproducible |
| 23 | + random.seed(seed) |
18 | 24 | # Load holo and alternate (apo and pred) structures |
19 | 25 | holo_structure = system.holo_structure |
20 | 26 | apo_structures = system.alternate_structures |
@@ -64,42 +70,39 @@ def system_featurizer( |
64 | 70 | "holo_protein_coordinates_stacked": holo_protein_coordinates_stacked, |
65 | 71 | "holo_protein_calpha_coordinates_stacked": holo_protein_calpha_coordinates_stacked, |
66 | 72 | } |
67 | | - |
68 | | - # Apo features |
69 | | - # Note: Since there are multiple apo/pred structures linked to a holo structure, we will |
70 | | - # stack all of them and let users choose how they want to use them |
71 | 73 | if featurize_apo: |
72 | | - all_apo_protein_atom_arrays = [] |
73 | | - all_apo_sequence_atom_masks_stacked = [] |
74 | | - all_apo_input_sequence_residue_masks_stacked = [] |
75 | | - all_apo_protein_coordinates_arrays_stacked = [] |
76 | | - all_apo_protein_calpha_coordinates_arrays_stacked = [] |
77 | | - for apo_id, apo_structure in apo_structures.items(): |
78 | | - apo_structure.set_chain(holo_chain) |
79 | | - all_apo_protein_atom_arrays.append(apo_structure.protein_atom_array) |
80 | | - all_apo_sequence_atom_masks_stacked.append(apo_structure.sequence_atom_mask) |
81 | | - all_apo_input_sequence_residue_masks_stacked.append( |
82 | | - apo_structure.input_sequence_residue_mask_stacked |
83 | | - ) |
84 | | - all_apo_protein_coordinates_arrays_stacked.append( |
85 | | - apo_structure.protein_coords |
86 | | - ) |
87 | | - all_apo_protein_calpha_coordinates_arrays_stacked.append( |
88 | | - apo_structure.protein_calpha_coords |
89 | | - ) |
90 | | - all_apo_features = { |
91 | | - "all_apo_sequence_atom_masks_stacked": all_apo_sequence_atom_masks_stacked, |
92 | | - "all_apo_input_sequence_residue_masks_stacked": all_apo_input_sequence_residue_masks_stacked, |
93 | | - "all_apo_protein_coordinates_arrays_stacked": all_apo_protein_coordinates_arrays_stacked, |
94 | | - "all_apo_protein_calpha_coordinates_arrays_stacked": all_apo_protein_calpha_coordinates_arrays_stacked, |
| 74 | + selected_apo = apo_structures[random.choice(list(apo_structures.keys()))] |
| 75 | + # Set apo chain to match holo |
| 76 | + selected_apo.set_chain(holo_chain) |
| 77 | + apo_sequence_atom_mask_stacked = selected_apo.sequence_atom_mask |
| 78 | + apo_input_sequence_residue_mask_stacked = ( |
| 79 | + selected_apo.input_sequence_residue_mask_stacked |
| 80 | + ) |
| 81 | + apo_protein_coordinates_array_stacked = selected_apo.protein_coords |
| 82 | + apo_protein_calpha_coordinates_array_stacked = ( |
| 83 | + selected_apo.protein_calpha_coords |
| 84 | + ) |
| 85 | + # Apo to holo cropping mask |
| 86 | + apo_sequence_to_holo_structure_mask_stacked = ( |
| 87 | + holo_structure.protein_structure_residue_mask(selected_apo) |
| 88 | + ) |
| 89 | + |
| 90 | + apo_features = { |
| 91 | + "apo_sequence_atom_masks_stacked": apo_sequence_atom_mask_stacked, |
| 92 | + "apo_input_sequence_residue_masks_stacked": apo_input_sequence_residue_mask_stacked, |
| 93 | + "apo_protein_coordinates_arrays_stacked": apo_protein_coordinates_array_stacked, |
| 94 | + "apo_protein_calpha_coordinates_arrays_stacked": apo_protein_calpha_coordinates_array_stacked, |
| 95 | + "apo_sequence_to_holo_structure_mask_stacked": apo_sequence_to_holo_structure_mask_stacked, |
95 | 96 | } |
96 | 97 | else: |
97 | | - all_apo_features = { |
98 | | - "all_apo_sequence_atom_masks_stacked": [], |
99 | | - "all_apo_input_sequence_residue_masks_stacked": [], |
100 | | - "all_apo_protein_coordinates_arrays_stacked": [], |
101 | | - "all_apo_protein_calpha_coordinates_arrays_stacked": [], |
| 98 | + apo_features = { |
| 99 | + "apo_sequence_atom_masks_stacked": [], |
| 100 | + "apo_input_sequence_residue_masks_stacked": [], |
| 101 | + "apo_protein_coordinates_arrays_stacked": [], |
| 102 | + "apo_protein_calpha_coordinates_arrays_stacked": [], |
| 103 | + "apo_sequence_to_holo_structure_mask_stacked": [], |
102 | 104 | } |
| 105 | + |
103 | 106 | # Ligand features |
104 | 107 | input_ligand_templates = ( |
105 | 108 | holo_structure.input_ligand_templates |
@@ -140,7 +143,7 @@ def system_featurizer( |
140 | 143 | features = { |
141 | 144 | "sequence_features": sequence_features, |
142 | 145 | "holo_features": holo_features, |
143 | | - "apo_features": all_apo_features, |
| 146 | + "apo_features": apo_features, |
144 | 147 | "ligand_features": ligand_features, |
145 | 148 | } |
146 | 149 |
|
|
0 commit comments