Skip to content

Commit addf096

Browse files
committed
able to override default atom field padding values during collation, set molecule_atom_lens to 0
1 parent a07e176 commit addf096

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,15 @@ def inner(x, *args, **kwargs):
159159
# atom level, what Alphafold3 accepts
160160

161161
UNCOLLATABLE_ATOM_INPUT_FIELDS = {'filepath'}
162-
ATOM_INPUT_EXCLUDE_MODEL_FIELDS = {'filepath', 'chains'}
162+
163+
ATOM_INPUT_EXCLUDE_MODEL_FIELDS = {
164+
'filepath',
165+
'chains'
166+
}
167+
168+
ATOM_DEFAULT_PAD_VALUES = dict(
169+
molecule_atom_lens = 0
170+
)
163171

164172
@typecheck
165173
@dataclass

alphafold3_pytorch/trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PDBInput,
2828
maybe_transform_to_atom_inputs,
2929
UNCOLLATABLE_ATOM_INPUT_FIELDS,
30+
ATOM_DEFAULT_PAD_VALUES,
3031
)
3132

3233
from alphafold3_pytorch.data import (
@@ -171,7 +172,9 @@ def collate_inputs_to_batched_atom_input(
171172

172173
# use -1 for padding int values, for assuming int are labels - if not, handle within alphafold3
173174

174-
if dtype in (torch.int, torch.long):
175+
if key in ATOM_DEFAULT_PAD_VALUES:
176+
pad_value = ATOM_DEFAULT_PAD_VALUES[key]
177+
elif dtype in (torch.int, torch.long):
175178
pad_value = int_pad_value
176179
elif dtype == torch.bool:
177180
pad_value = False

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.125"
3+
version = "0.2.126"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)