Skip to content

Commit 5445147

Browse files
authored
Ensure no null BatchedAtomInputs are passed to the Trainer (#214)
Updates alphafold3_pytorch/trainer.py, alphafold3_pytorch/inputs.py, alphafold3_pytorch/utils/utils.py, and pyproject.toml.
1 parent ee60fd1 commit 5445147

File tree

4 files changed

+66
-21
lines changed

4 files changed

+66
-21
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from io import StringIO
1414
from itertools import groupby
1515
from pathlib import Path
16+
from retrying import retry
1617
from typing import Any, Callable, Dict, List, Literal, Set, Tuple, Type
1718

1819
import einx
@@ -74,7 +75,7 @@
7475
remove_consecutive_duplicate,
7576
)
7677
from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck
77-
from alphafold3_pytorch.utils.utils import default, exists, first
78+
from alphafold3_pytorch.utils.utils import default, exists, first, not_exists
7879

7980
# silence RDKit's warnings
8081

@@ -3303,7 +3304,7 @@ def pdb_input_to_molecule_input(
33033304

33043305
# datasets
33053306

3306-
# PDB dataset that returns a PDBInput based on folder
3307+
# PDB dataset that returns either a PDBInput or AtomInput based on folder
33073308

33083309

33093310
class PDBDataset(Dataset):
@@ -3321,6 +3322,7 @@ def __init__(
33213322
crop_size: int = 384,
33223323
training: bool | None = None, # extra training flag placed by Alex on PDBInput
33233324
sample_only_pdb_ids: Set[str] | None = None,
3325+
return_atom_inputs: bool = False,
33243326
**pdb_input_kwargs,
33253327
):
33263328
if isinstance(folder, str):
@@ -3333,6 +3335,7 @@ def __init__(
33333335
self.sample_type = sample_type
33343336
self.training = training
33353337
self.sample_only_pdb_ids = sample_only_pdb_ids
3338+
self.return_atom_inputs = return_atom_inputs
33363339
self.pdb_input_kwargs = pdb_input_kwargs
33373340

33383341
self.cropping_config = {
@@ -3369,8 +3372,8 @@ def __len__(self):
33693372
"""Return the number of PDB mmCIF files in the dataset."""
33703373
return len(self.files)
33713374

3372-
def __getitem__(self, idx: int | str) -> PDBInput:
3373-
"""Return a PDBInput object for the specified index."""
3375+
def get_item(self, idx: int | str) -> PDBInput | AtomInput | None:
3376+
"""Return either a PDBInput or an AtomInput object for the specified index."""
33743377
sampled_id = None
33753378

33763379
if exists(self.sampler):
@@ -3412,15 +3415,24 @@ def __getitem__(self, idx: int | str) -> PDBInput:
34123415
if self.training:
34133416
cropping_config = self.cropping_config
34143417

3415-
pdb_input = PDBInput(
3418+
i = PDBInput(
34163419
mmcif_filepath=str(mmcif_filepath),
34173420
chains=(chain_id_1, chain_id_2),
34183421
cropping_config=cropping_config,
34193422
training=self.training,
34203423
**self.pdb_input_kwargs,
34213424
)
34223425

3423-
return pdb_input
3426+
if self.return_atom_inputs:
3427+
i = maybe_transform_to_atom_input(i)
3428+
3429+
return i
3430+
3431+
def __getitem__(self, idx: int | str, max_attempts: int = 5) -> PDBInput | AtomInput:
3432+
"""Return either a PDBInput or an AtomInput object for the specified index."""
3433+
retry_decorator = retry(retry_on_result=not_exists, stop_max_attempt_number=max_attempts)
3434+
i = retry_decorator(self.get_item)(idx)
3435+
return i
34243436

34253437

34263438
# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
@@ -3461,7 +3473,7 @@ def maybe_transform_to_atom_input(i: Any, raise_exception: bool = False) -> Atom
34613473

34623474
if not exists(maybe_to_atom_fn):
34633475
raise TypeError(
3464-
f"invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly"
3476+
f"Invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly"
34653477
)
34663478

34673479
try:

alphafold3_pytorch/trainer.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def collate_inputs_to_batched_atom_input(
111111
inputs: List,
112112
int_pad_value = -1,
113113
atoms_per_window: int | None = None,
114-
map_input_fn: Callable | None = None
115-
114+
map_input_fn: Callable | None = None,
115+
transform_to_atom_inputs: bool = True,
116116
) -> BatchedAtomInput:
117117

118118
if exists(map_input_fn):
@@ -121,7 +121,25 @@ def collate_inputs_to_batched_atom_input(
121121
# go through all the inputs
122122
# and for any that is not AtomInput, try to transform it with the registered input type to corresponding registered function
123123

124-
atom_inputs = maybe_transform_to_atom_inputs(inputs)
124+
if transform_to_atom_inputs:
125+
atom_inputs = maybe_transform_to_atom_inputs(inputs)
126+
127+
if len(atom_inputs) < len(inputs):
128+
# if some of the `inputs` could not be converted into `atom_inputs`,
129+
# randomly select a subset of the `atom_inputs` to duplicate to match
130+
# the expected number of `atom_inputs`
131+
assert (
132+
len(atom_inputs) > 0
133+
), "No `AtomInput` objects could be created for the current batch."
134+
atom_inputs = random.choices(atom_inputs, k=len(inputs)) # nosec
135+
else:
136+
atom_inputs = inputs
137+
138+
assert all(isinstance(i, AtomInput) for i in atom_inputs), (
139+
"All inputs must be of type `AtomInput`. "
140+
"If you want to transform the inputs to `AtomInput`, "
141+
"set `transform_to_atom_inputs=True`."
142+
)
125143

126144
# take care of windowing the atompair_inputs and atompair_ids if they are not windowed already
127145

@@ -248,9 +266,14 @@ def DataLoader(
248266
*args,
249267
atoms_per_window: int | None = None,
250268
map_input_fn: Callable | None = None,
269+
transform_to_atom_inputs: bool = True,
251270
**kwargs
252271
):
253-
collate_fn = partial(collate_inputs_to_batched_atom_input, atoms_per_window = atoms_per_window)
272+
collate_fn = partial(
273+
collate_inputs_to_batched_atom_input,
274+
atoms_per_window = atoms_per_window,
275+
transform_to_atom_inputs = transform_to_atom_inputs,
276+
)
254277

255278
if exists(map_input_fn):
256279
collate_fn = partial(collate_fn, map_input_fn = map_input_fn)

alphafold3_pytorch/utils/utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
import numpy as np
22

3-
from typing import Any, List
4-
5-
def first(arr: List) -> Any:
6-
"""
7-
Returns first element of list
8-
9-
:param arr: the list
10-
:return: the element
11-
"""
12-
return arr[0]
3+
from typing import Any, Iterable, List
134

145

156
def exists(val: Any) -> bool:
@@ -21,6 +12,15 @@ def exists(val: Any) -> bool:
2112
return val is not None
2213

2314

15+
def not_exists(val: Any) -> bool:
16+
"""Check if a value does not exist.
17+
18+
:param val: The value to check.
19+
:return: `True` if the value does not exist, otherwise `False`.
20+
"""
21+
return val is None
22+
23+
2424
def default(v: Any, d: Any) -> Any:
2525
"""Return default value `d` if `v` does not exist (i.e., is `None`).
2626
@@ -31,6 +31,15 @@ def default(v: Any, d: Any) -> Any:
3131
return v if exists(v) else d
3232

3333

34+
def first(arr: Iterable[Any]) -> Any:
35+
"""Return the first element of an iterable object such as a list.
36+
37+
:param arr: An iterable object.
38+
:return: The first element of the iterable object.
39+
"""
40+
return arr[0]
41+
42+
3443
def always(value):
3544
"""Always return a value."""
3645

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"pydantic>=2.8.2",
4848
"pyyaml",
4949
"rdkit>=2023.9.6",
50+
"retrying",
5051
"scikit-learn>=1.5.0",
5152
"sh>=2.0.7",
5253
"shortuuid",

0 commit comments

Comments (0)