|
41 | 41 | from torch.nn.utils.rnn import pad_sequence |
42 | 42 | from torch.utils.data import Dataset |
43 | 43 |
|
44 | | -from alphafold3_pytorch.common import amino_acid_constants, dna_constants, rna_constants |
| 44 | +from alphafold3_pytorch.common import ( |
| 45 | + amino_acid_constants, |
| 46 | + dna_constants, |
| 47 | + rna_constants |
| 48 | +) |
45 | 49 | from alphafold3_pytorch.common.biomolecule import ( |
46 | 50 | Biomolecule, |
47 | 51 | _from_mmcif_object, |
|
86 | 90 | from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck |
87 | 91 | from alphafold3_pytorch.utils.utils import default, exists, first, not_exists |
88 | 92 |
|
| 93 | +from alphafold3_pytorch.attention import ( |
| 94 | + full_pairwise_repr_to_windowed, |
| 95 | + full_attn_bias_to_windowed, |
| 96 | + pad_at_dim, |
| 97 | + pad_or_slice_to |
| 98 | +) |
| 99 | + |
89 | 100 | # silence RDKit's warnings |
90 | 101 |
|
91 | 102 | RDLogger.DisableLog("rdApp.*") |
@@ -1821,6 +1832,17 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]: |
1821 | 1832 |
|
1822 | 1833 | return molecule_input |
1823 | 1834 |
|
| 1835 | +@typecheck |
| 1836 | +def alphafold3_inputs_to_batched_atom_input( |
| 1837 | + inp: Alphafold3Input | List[Alphafold3Input], |
| 1838 | + **collate_kwargs |
| 1839 | +) -> BatchedAtomInput: |
| 1840 | + |
| 1841 | + if isinstance(inp, Alphafold3Input): |
| 1842 | + inp = [inp] |
| 1843 | + |
| 1844 | + atom_inputs = maybe_transform_to_atom_inputs(inp) |
| 1845 | + return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs) |
1824 | 1846 |
|
1825 | 1847 | # pdb input |
1826 | 1848 |
|
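A minimal usage sketch for the helper added in the hunk above. The `Alphafold3Input` constructor is not shown in this diff, so the `proteins` keyword and the `alphafold3_pytorch.inputs` module path below are assumptions, not confirmed API; the forwarded `atoms_per_window` kwarg comes from `collate_inputs_to_batched_atom_input` further down.

# assumed module path and constructor arguments - adjust to the actual API
from alphafold3_pytorch.inputs import (
    Alphafold3Input,
    alphafold3_inputs_to_batched_atom_input
)

# hypothetical single-letter protein sequences
single_input = Alphafold3Input(proteins = ['AG', 'MLEI'])

# a single input is wrapped into a list internally; extra kwargs are
# forwarded to collate_inputs_to_batched_atom_input (e.g. atoms_per_window)
batched = alphafold3_inputs_to_batched_atom_input(single_input, atoms_per_window = 27)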
@@ -3366,6 +3388,17 @@ def pdb_input_to_molecule_input( |
3366 | 3388 |
|
3367 | 3389 | return molecule_input |
3368 | 3390 |
|
| 3391 | +@typecheck |
| 3392 | +def pdb_inputs_to_batched_atom_input( |
| 3393 | + inp: PDBInput | List[PDBInput], |
| 3394 | + **collate_kwargs |
| 3395 | +) -> BatchedAtomInput: |
| 3396 | + |
| 3397 | + if isinstance(inp, PDBInput): |
| 3398 | + inp = [inp] |
| 3399 | + |
| 3400 | + atom_inputs = maybe_transform_to_atom_inputs(inp) |
| 3401 | + return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs) |
3369 | 3402 |
|
3370 | 3403 | # datasets |
3371 | 3404 |
|
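Likewise, a sketch of the PDB-side helper added above. How `PDBInput` is constructed (here, from an mmCIF file path) and the module path are assumptions not visible in this hunk.

# assumed module path and constructor - adjust to the actual API
from alphafold3_pytorch.inputs import PDBInput, pdb_inputs_to_batched_atom_input

# hypothetical mmCIF file path
pdb_input = PDBInput('data/mmcifs/7a4d-assembly1.cif')

# a single PDBInput is wrapped into a list; kwargs go to the collate function
batched = pdb_inputs_to_batched_atom_input(pdb_input, atoms_per_window = 27)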
@@ -3517,6 +3550,142 @@ def __getitem__(self, idx: int | str, max_attempts: int = 10) -> PDBInput | Atom |
3517 | 3550 |
|
3518 | 3551 | return i |
3519 | 3552 |
|
| 3553 | +# collation function |
| 3554 | + |
| 3555 | +@typecheck |
| 3556 | +def collate_inputs_to_batched_atom_input( |
| 3557 | + inputs: List, |
| 3558 | + int_pad_value = -1, |
| 3559 | + atoms_per_window: int | None = None, |
| 3560 | + map_input_fn: Callable | None = None, |
| 3561 | + transform_to_atom_inputs: bool = True, |
| 3562 | +) -> BatchedAtomInput: |
| 3563 | + |
| 3564 | + if exists(map_input_fn): |
| 3565 | + inputs = [map_input_fn(i) for i in inputs] |
| 3566 | + |
| 3567 | + # go through all the inputs |
| 3568 | + # and, for any that are not AtomInput, try to transform them with the function registered for their input type |
| 3569 | + |
| 3570 | + if transform_to_atom_inputs: |
| 3571 | + atom_inputs = maybe_transform_to_atom_inputs(inputs) |
| 3572 | + |
| 3573 | + if len(atom_inputs) < len(inputs): |
| 3574 | + # if some of the `inputs` could not be converted into `atom_inputs`, |
| 3575 | + # randomly select a subset of the `atom_inputs` to duplicate to match |
| 3576 | + # the expected number of `atom_inputs` |
| 3577 | + assert ( |
| 3578 | + len(atom_inputs) > 0 |
| 3579 | + ), "No `AtomInput` objects could be created for the current batch." |
| 3580 | + atom_inputs = random.choices(atom_inputs, k=len(inputs)) # nosec |
| 3581 | + else: |
| 3582 | + assert all(isinstance(i, AtomInput) for i in inputs), ( |
| 3583 | + "When `transform_to_atom_inputs=False`, all provided " |
| 3584 | + "inputs must be of type `AtomInput`." |
| 3585 | + ) |
| 3586 | + atom_inputs = inputs |
| 3587 | + |
| 3588 | + assert all(isinstance(i, AtomInput) for i in atom_inputs), ( |
| 3589 | + "All inputs must be of type `AtomInput`. " |
| 3590 | + "If you want to transform the inputs to `AtomInput`, " |
| 3591 | + "set `transform_to_atom_inputs=True`." |
| 3592 | + ) |
| 3593 | + |
| 3594 | + # take care of windowing the atompair_inputs and atompair_ids if they are not windowed already |
| 3595 | + |
| 3596 | + if exists(atoms_per_window): |
| 3597 | + for atom_input in atom_inputs: |
| 3598 | + atompair_inputs = atom_input.atompair_inputs |
| 3599 | + atompair_ids = atom_input.atompair_ids |
| 3600 | + |
| 3601 | + atompair_inputs_is_windowed = atompair_inputs.ndim == 4 |
| 3602 | + |
| 3603 | + if not atompair_inputs_is_windowed: |
| 3604 | + atom_input.atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window) |
| 3605 | + |
| 3606 | + if exists(atompair_ids): |
| 3607 | + atompair_ids_is_windowed = atompair_ids.ndim == 3 |
| 3608 | + |
| 3609 | + if not atompair_ids_is_windowed: |
| 3610 | + atom_input.atompair_ids = full_attn_bias_to_windowed(atompair_ids, window_size = atoms_per_window) |
| 3611 | + |
| 3612 | + # separate input dictionary into keys and values |
| 3613 | + |
| 3614 | + keys = list(atom_inputs[0].dict().keys()) |
| 3615 | + atom_inputs = [i.dict().values() for i in atom_inputs] |
| 3616 | + |
| 3617 | + outputs = [] |
| 3618 | + |
| 3619 | + for key, grouped in zip(keys, zip(*atom_inputs)): |
| 3620 | + # if all None, just collate to None |
| 3621 | + |
| 3622 | + not_none_grouped = [*filter(exists, grouped)] |
| 3623 | + |
| 3624 | + if len(not_none_grouped) == 0: |
| 3625 | + outputs.append(None) |
| 3626 | + continue |
| 3627 | + |
| 3628 | + # collate lists for uncollatable fields |
| 3629 | + |
| 3630 | + if key in UNCOLLATABLE_ATOM_INPUT_FIELDS: |
| 3631 | + outputs.append(grouped) |
| 3632 | + continue |
| 3633 | + |
| 3634 | + # default to a tensor of pad values for any Nones |
| 3635 | + |
| 3636 | + one_tensor = not_none_grouped[0] |
| 3637 | + |
| 3638 | + dtype = one_tensor.dtype |
| 3639 | + ndim = one_tensor.ndim |
| 3640 | + |
| 3641 | + # use -1 for padding int values, assuming ints are labels - if not, handle within alphafold3 |
| 3642 | + |
| 3643 | + if key in ATOM_DEFAULT_PAD_VALUES: |
| 3644 | + pad_value = ATOM_DEFAULT_PAD_VALUES[key] |
| 3645 | + elif dtype in (torch.int, torch.long): |
| 3646 | + pad_value = int_pad_value |
| 3647 | + elif dtype == torch.bool: |
| 3648 | + pad_value = False |
| 3649 | + else: |
| 3650 | + pad_value = 0. |
| 3651 | + |
| 3652 | + # get the max lengths across all dimensions |
| 3653 | + |
| 3654 | + shapes_as_tensor = torch.stack([tensor(tuple(g.shape) if exists(g) else ((0,) * ndim)).int() for g in grouped], dim = -1) |
| 3655 | + |
| 3656 | + max_lengths = shapes_as_tensor.amax(dim = -1) |
| 3657 | + |
| 3658 | + default_tensor = torch.full(max_lengths.tolist(), pad_value, dtype = dtype) |
| 3659 | + |
| 3660 | + # pad across all dimensions |
| 3661 | + |
| 3662 | + padded_inputs = [] |
| 3663 | + |
| 3664 | + for inp in grouped: |
| 3665 | + |
| 3666 | + if not exists(inp): |
| 3667 | + padded_inputs.append(default_tensor) |
| 3668 | + continue |
| 3669 | + |
| 3670 | + for dim, max_length in enumerate(max_lengths.tolist()): |
| 3671 | + inp = pad_at_dim(inp, (0, max_length - inp.shape[dim]), value = pad_value, dim = dim) |
| 3672 | + |
| 3673 | + padded_inputs.append(inp) |
| 3674 | + |
| 3675 | + # stack |
| 3676 | + |
| 3677 | + stacked = torch.stack(padded_inputs) |
| 3678 | + |
| 3679 | + outputs.append(stacked) |
| 3680 | + |
| 3681 | + # batched atom input dictionary |
| 3682 | + |
| 3683 | + batched_atom_input_dict = dict(tuple(zip(keys, outputs))) |
| 3684 | + |
| 3685 | + # reconstitute dictionary |
| 3686 | + |
| 3687 | + batched_atom_inputs = BatchedAtomInput(**batched_atom_input_dict) |
| 3688 | + return batched_atom_inputs |
3520 | 3689 |
|
3521 | 3690 | # the config used for keeping track of all the disparate inputs and their transforms down to AtomInput |
3522 | 3691 | # this can be preprocessed or will be taken care of automatically within the Trainer during data collation |
|
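A sketch of wiring the new collation function into a torch `DataLoader`; the dataset name below is a placeholder for any dataset whose `__getitem__` returns `PDBInput` or `AtomInput` items (as the dataset in the hunk above does), and the keyword arguments are the ones defined on `collate_inputs_to_batched_atom_input`.

from functools import partial

from torch.utils.data import DataLoader

# fix the collation options once, then hand the partial to the DataLoader
collate_fn = partial(
    collate_inputs_to_batched_atom_input,
    atoms_per_window = 27,   # window atompair_inputs / atompair_ids during collation
    int_pad_value = -1       # integer fields are padded with -1 (treated as labels)
)

# `pdb_dataset` is a placeholder dataset yielding PDBInput / AtomInput items
dataloader = DataLoader(pdb_dataset, batch_size = 2, collate_fn = collate_fn)

# each batch is a BatchedAtomInput whose tensor fields are padded to the
# largest shape in the batch along every dimension
batch = next(iter(dataloader))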