Skip to content

Commit ca644de

Browse files
committed
simple function for extending input transform config
1 parent 4059aba commit ca644de

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
Alphafold3
3333
)
3434

35+
from alphafold3_pytorch.inputs import (
36+
register_input_transform
37+
)
38+
3539
from alphafold3_pytorch.trainer import (
3640
Trainer,
3741
DataLoader,

alphafold3_pytorch/inputs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypedDict, Literal
1+
from typing import Type, TypedDict, Literal, Callable
22

33
from alphafold3_pytorch.typing import (
44
typecheck,
@@ -83,3 +83,12 @@ def single_protein_input_and_single_nucleic_acid_to_atom_input(
8383
SingleProteinInput: single_protein_input_to_atom_input,
8484
SingleProteinSingleNucleicAcidInput: single_protein_input_and_single_nucleic_acid_to_atom_input
8585
}
86+
87+
# function for extending the config
88+
89+
@typecheck
90+
def register_input_transform(
91+
input_type: Type,
92+
fn: Callable[[TypedDict], AtomInput]
93+
):
94+
INPUT_TO_ATOM_TRANSFORM[input_type] = fn

alphafold3_pytorch/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def collate_af3_inputs(
9393
if not exists(maybe_to_atom_fn):
9494
raise TypeError(f'invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly')
9595

96-
atom_inputs = maybe_to_atom_fn(i)
96+
atom_inputs.append(maybe_to_atom_fn(i))
9797

9898
# separate input dictionary into keys and values
9999

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.45"
3+
version = "0.1.47"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)