Skip to content

Commit aaee922

Browse files
committed
fix collation issue with output_atompos_indices
1 parent 501aaa4 commit aaee922

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
from alphafold3_pytorch.alphafold3 import Alphafold3
8-
from alphafold3_pytorch.attention import pad_at_dim
8+
from alphafold3_pytorch.attention import pad_at_dim, pad_or_slice_to
99

1010
from typing import TypedDict, List, Callable
1111

@@ -169,9 +169,21 @@ def collate_inputs_to_batched_atom_input(
169169

170170
outputs.append(stacked)
171171

172+
# batched atom input dictionary
173+
174+
batched_atom_input_dict = dict(tuple(zip(keys, outputs)))
175+
176+
# just ensure output_atompos_indices has full atom_seq_len manually for now
177+
178+
output_atompos_indices = batched_atom_input_dict.get('output_atompos_indices', None)
179+
180+
if exists(output_atompos_indices):
181+
atom_seq_len = batched_atom_input_dict['atom_inputs'].shape[-2]
182+
batched_atom_input_dict.update(output_atompos_indices = pad_or_slice_to(output_atompos_indices, atom_seq_len, dim = -1, pad_value = -1))
183+
172184
# reconstitute dictionary
173185

174-
batched_atom_inputs = BatchedAtomInput(**dict(tuple(zip(keys, outputs))))
186+
batched_atom_inputs = BatchedAtomInput(**batched_atom_input_dict)
175187
return batched_atom_inputs
176188

177189
@typecheck

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.91"
3+
version = "0.1.92"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ def test_collate_fn():
225225
),
226226
)
227227

228-
dataset = MockAtomDataset(1)
228+
dataset = MockAtomDataset(5)
229229

230-
batched_atom_inputs = collate_inputs_to_batched_atom_input([dataset[0]])
230+
batched_atom_inputs = collate_inputs_to_batched_atom_input([dataset[i] for i in range(3)])
231231

232232
_, breakdown = alphafold3(**asdict(batched_atom_inputs), return_loss_breakdown = True)
233233

0 commit comments

Comments
 (0)