Skip to content

Commit e4c5b7d

Browse files
committed
export fns for yaml config -> Alphafold3 or Trainer
1 parent 2515ad2 commit e4c5b7d

File tree

5 files changed

+22
-8
lines changed

5 files changed

+22
-8
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040

4141
from alphafold3_pytorch.configs import (
4242
Alphafold3Config,
43-
TrainerConfig
43+
TrainerConfig,
44+
create_alphafold3_from_yaml,
45+
create_trainer_from_yaml
4446
)
4547

4648
__all__ = [
@@ -74,5 +76,7 @@
7476
Alphafold3Config,
7577
AtomInput,
7678
Trainer,
77-
TrainerConfig
79+
TrainerConfig,
80+
create_alphafold3_from_yaml,
81+
create_trainer_from_yaml
7882
]

alphafold3_pytorch/configs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,8 @@ def create_instance_from_yaml_file(
142142

143143
trainer_config = TrainerConfig.from_yaml_file(path)
144144
return trainer_config.create_instance(**kwargs)
145+
146+
# convenience functions
147+
148+
create_alphafold3_from_yaml = Alphafold3Config.create_instance_from_yaml_file
149+
create_trainer_from_yaml = TrainerConfig.create_instance_from_yaml_file

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.24"
3+
version = "0.1.25"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
InputFeatureEmbedder,
2323
ConfidenceHead,
2424
DistogramHead,
25-
Alphafold3,
26-
Alphafold3Config
25+
Alphafold3
26+
)
27+
28+
from alphafold3_pytorch.configs import (
29+
Alphafold3Config,
30+
create_alphafold3_from_yaml
2731
)
2832

2933
from alphafold3_pytorch.alphafold3 import (
@@ -574,5 +578,5 @@ def test_alphafold3_config():
574578
curr_dir = Path(__file__).parents[0]
575579
af3_yaml = curr_dir / 'alphafold3.yaml'
576580

577-
alphafold3 = Alphafold3Config.create_instance_from_yaml_file(af3_yaml)
581+
alphafold3 = create_alphafold3_from_yaml(af3_yaml)
578582
assert isinstance(alphafold3, Alphafold3)

tests/test_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
AtomInput,
1414
DataLoader,
1515
Trainer,
16-
TrainerConfig
16+
TrainerConfig,
17+
create_trainer_from_yaml
1718
)
1819

1920
# mock dataset
@@ -173,7 +174,7 @@ def test_trainer_config():
173174
curr_dir = Path(__file__).parents[0]
174175
trainer_yaml_path = curr_dir / 'trainer.yaml'
175176

176-
trainer = TrainerConfig.create_instance_from_yaml_file(
177+
trainer = create_trainer_from_yaml(
177178
trainer_yaml_path,
178179
dataset = MockAtomDataset(16)
179180
)

0 commit comments

Comments
 (0)