|
139 | 139 |
|
140 | 140 | # functions |
141 | 141 |
|
def l2norm(t):
    """Rescale `t` so that every vector along its last dimension has unit L2 norm."""
    unit_t = F.normalize(t, p = 2.0, dim = -1)
    return unit_t
| 144 | + |
def flatten(arr):
    """Concatenate the elements of each sub-sequence of `arr` into one list (a single level deep)."""
    flattened = []

    for sub_arr in arr:
        flattened.extend(sub_arr)

    return flattened
144 | 147 |
|
@@ -199,6 +202,96 @@ def hard_validate_atom_indices_ascending( |
199 | 202 |
|
200 | 203 | assert (difference >= 0).all(), f'detected invalid {error_msg_field} for in a batch: {present_indices}' |
201 | 204 |
|
| 205 | +# functions for deriving the frames for ligands |
| 206 | +# this follows the logic from Alphafold3 Supplementary section 4.3.2 |
| 207 | + |
@typecheck
def get_indices_three_closest_atom_pos(
    atom_pos: Float['... n d'],
) -> Int['... 3']:
    """
    Pick the three atom indices used as a frame.

    The atom whose two nearest neighbors have the smallest mean distance to it
    is chosen as the center, and the returned indices are ordered as
    (first neighbor, center, second neighbor) along the last dimension.

    When there are fewer than three atoms, a long tensor filled with -1 is
    returned instead, shaped as the leading dimensions of `atom_pos` plus 3.

    A leading dimension of size 1 is added temporarily unless `atom_pos` has
    exactly three dimensions, and removed again before returning.
    """

    shape, device = atom_pos.shape, atom_pos.device
    leading_dims, num_atoms = shape[:-2], shape[-2]
    is_batched = atom_pos.ndim == 3

    # fewer than three atoms cannot make up a frame, signal this with -1 indices

    if num_atoms < 3:
        return atom_pos.new_full((*leading_dims, 3), -1).long()

    if not is_batched:
        atom_pos = rearrange(atom_pos, '... -> 1 ...')

    pairwise_dist = torch.cdist(atom_pos, atom_pos)

    # overwrite the diagonal with a large value, so an atom is never its own neighbor

    self_mask = torch.eye(num_atoms, device = device, dtype = torch.bool)
    pairwise_dist.masked_fill_(self_mask, 1e4)

    # the two largest negated distances are the two nearest neighbors

    neg_nearest_dist, nearest_two_indices = pairwise_dist.neg().topk(2, dim = -1)

    # center atom is the one with the smallest mean distance to its two nearest neighbors

    center_index = neg_nearest_dist.mean(dim = -1).argmax(dim = -1)

    neighbor_indices = einx.get_at('... [m] c, ... -> ... c', nearest_two_indices, center_index)
    first_neighbor, second_neighbor = neighbor_indices.unbind(dim = -1)

    # the center atom goes in the middle of the triplet

    triplet = (first_neighbor, center_index, second_neighbor)
    three_atom_indices, _ = pack(triplet, 'b *')

    if not is_batched:
        three_atom_indices = rearrange(three_atom_indices, '1 ... -> ...')

    return three_atom_indices
| 252 | + |
@typecheck
def get_angle_between_edges(
    edge1: Float['... 3'],
    edge2: Float['... 3']
) -> Float['...']:
    """
    Angle in radians between two edge vectors, computed over the last dimension.

    Both edges are l2-normalized first, so only their directions matter.
    Any number of leading dimensions is supported, and the result keeps
    those leading dimensions.
    """

    # torch.dot only accepts 1D tensors, so take the dot product as an
    # elementwise product summed over the last dimension to allow batched edges

    cos = (l2norm(edge1) * l2norm(edge2)).sum(dim = -1)

    # rounding can push the cosine marginally outside of [-1, 1], where acos returns nan

    cos = cos.clamp(min = -1., max = 1.)

    return torch.acos(cos)
| 260 | + |
@typecheck
def get_frames_from_atom_pos(
    atom_pos: Float['... n d'],
    filter_colinear_pos: bool = False,
    is_colinear_angle_thres: float = 25. # they use 25 degrees as a way of filtering out invalid frames
) -> Int['... 3']:
    """
    Derive a frame (three atom indices) from atom positions.

    The indices come from `get_indices_three_closest_atom_pos`. They are
    returned unchanged unless `filter_colinear_pos` is set. When it is set,
    the angle at the middle atom between the edges to the two outer atoms is
    measured, and any frame whose angle is within `is_colinear_angle_thres`
    degrees of either 0 or 180 has all three of its indices replaced by -1.
    """

    frame_indices = get_indices_three_closest_atom_pos(atom_pos)

    if not filter_colinear_pos:
        return frame_indices

    # gather the positions of the three atoms of each frame

    frame_pos = einx.get_at('... [m] c, ... three -> ... three c', atom_pos, frame_indices)

    first_pos, middle_pos, last_pos = frame_pos.unbind(dim = -2)

    # angle at the middle atom, between the edges going out to the two outer atoms

    angle = get_angle_between_edges(first_pos - middle_pos, last_pos - middle_pos)

    abs_degree = torch.rad2deg(angle).abs()

    # colinear if the angle is close to either 0 or 180 degrees

    near_zero = abs_degree < is_colinear_angle_thres
    near_straight = (180. - abs_degree) < is_colinear_angle_thres

    is_colinear = near_zero | near_straight

    # colinear frames are invalidated with -1 indices

    return einx.where('..., ... three, -> ... three', ~is_colinear, frame_indices, -1)
| 294 | + |
202 | 295 | # atom level, what Alphafold3 accepts |
203 | 296 |
|
# NOTE(review): presumably the names of atom input fields that are left out when collating a batch - confirm against the collation code
UNCOLLATABLE_ATOM_INPUT_FIELDS = {'filepath'}
|
0 commit comments