Skip to content

Commit 9520dc8

Browse files
committed
Exclude the `filepath` and `chains` fields from the Alphafold3 `forward` pass
1 parent 0a8757b commit 9520dc8

File tree

5 files changed

+21
-13
lines changed

5 files changed

+21
-13
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4832,8 +4832,6 @@ def forward(
48324832
pde_labels: Int['b n n'] | Int['b m m'] | None = None,
48334833
plddt_labels: Int['b n'] | Int['b m'] | None = None,
48344834
resolved_labels: Int['b n'] | Int['b m'] | None = None,
4835-
chains: Int['b 2'] | None = None,
4836-
filepath: List[str] | None = None,
48374835
return_loss_breakdown = False,
48384836
return_loss: bool = None,
48394837
return_present_sampled_atoms: bool = False,

alphafold3_pytorch/inputs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@
125125
def flatten(arr):
126126
return [el for sub_arr in arr for el in sub_arr]
127127

128+
def without_keys(d: dict, exclude: set):
129+
return {k: v for k, v in d.items() if k not in exclude}
130+
128131
def pad_to_len(t, length, value = 0, dim = -1):
129132
assert dim < 0
130133
zeros = (0, 0) * (-dim - 1)
@@ -149,6 +152,9 @@ def inner(x, *args, **kwargs):
149152

150153
# atom level, what Alphafold3 accepts
151154

155+
UNCOLLATABLE_ATOM_INPUT_FIELDS = {'filepath'}
156+
ATOM_INPUT_EXCLUDE_MODEL_FIELDS = {'filepath', 'chains'}
157+
152158
@typecheck
153159
@dataclass
154160
class AtomInput:
@@ -217,6 +223,9 @@ class BatchedAtomInput:
217223
def dict(self):
218224
return asdict(self)
219225

226+
def model_forward_dict(self):
227+
return without_keys(self.dict(), ATOM_INPUT_EXCLUDE_MODEL_FIELDS)
228+
220229
# functions for saving an AtomInput to disk or loading from disk to AtomInput
221230

222231
@typecheck

alphafold3_pytorch/trainer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
BatchedAtomInput,
2727
Alphafold3Input,
2828
PDBInput,
29-
maybe_transform_to_atom_inputs
29+
maybe_transform_to_atom_inputs,
30+
UNCOLLATABLE_ATOM_INPUT_FIELDS,
3031
)
3132

3233
from alphafold3_pytorch.data import (
@@ -147,7 +148,7 @@ def collate_inputs_to_batched_atom_input(
147148

148149
outputs = []
149150

150-
for group_index, grouped in enumerate(zip(*atom_inputs)):
151+
for key, grouped in zip(keys, zip(*atom_inputs)):
151152
# if all None, just return None
152153

153154
not_none_grouped = [*filter(exists, grouped)]
@@ -158,8 +159,8 @@ def collate_inputs_to_batched_atom_input(
158159

159160
# pass through uncollatable fields (e.g. filepath strings) as-is, without tensor collation
160161

161-
if keys[group_index] == "filepath":
162-
outputs.append(not_none_grouped)
162+
if key in UNCOLLATABLE_ATOM_INPUT_FIELDS:
163+
outputs.append(grouped)
163164
continue
164165

165166
# default to empty tensor for any Nones
@@ -640,7 +641,7 @@ def __call__(
640641
# model forwards
641642

642643
loss, loss_breakdown = self.model(
643-
**inputs.dict(),
644+
**inputs.model_forward_dict(),
644645
return_loss_breakdown = True
645646
)
646647

@@ -702,7 +703,7 @@ def __call__(
702703

703704
for valid_batch in self.valid_dataloader:
704705
valid_loss, loss_breakdown = eval_model(
705-
**valid_batch.dict(),
706+
**valid_batch.model_forward_dict(),
706707
return_loss_breakdown = True
707708
)
708709

@@ -742,7 +743,7 @@ def __call__(
742743

743744
for test_batch in self.test_dataloader:
744745
test_loss, loss_breakdown = eval_model(
745-
**test_batch.dict(),
746+
**test_batch.model_forward_dict(),
746747
return_loss_breakdown = True
747748
)
748749

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.118"
3+
version = "0.2.119"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_input.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_alphafold3_input(directed_bonds):
115115
)
116116
)
117117

118-
alphafold3(**batched_atom_input.dict(), num_sample_steps = 1)
118+
alphafold3(**batched_atom_input.model_forward_dict(), num_sample_steps = 1)
119119

120120
def test_atompos_input():
121121

@@ -168,7 +168,7 @@ def test_atompos_input():
168168
)
169169
)
170170

171-
loss = alphafold3(**batched_atom_input.dict())
171+
loss = alphafold3(**batched_atom_input.model_forward_dict())
172172
loss.backward()
173173

174174
# sampling
@@ -244,7 +244,7 @@ def test_pdbinput_input():
244244
),
245245
)
246246

247-
loss = alphafold3(**batched_atom_input.dict())
247+
loss = alphafold3(**batched_atom_input.model_forward_dict())
248248
loss.backward()
249249

250250
# sampling

0 commit comments

Comments (0)