Skip to content

Commit 20c3006

Browse files
committed
first: offer a way to create an Alphafold3 instance from a config file
1 parent 2e7109b commit 20c3006

File tree

4 files changed

+91
-8
lines changed

4 files changed

+91
-8
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
AtomInput
3939
)
4040

41+
from alphafold3_pytorch.configs import (
42+
Alphafold3Config
43+
)
44+
4145
__all__ = [
4246
Attention,
4347
Attend,

alphafold3_pytorch/alphafold3.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,7 +2694,7 @@ def __init__(
26942694
self,
26952695
*,
26962696
dim_single_inputs,
2697-
atompair_dist_bins: Float[' d'],
2697+
atompair_dist_bins: List[float],
26982698
dim_single = 384,
26992699
dim_pairwise = 128,
27002700
num_plddt_bins = 50,
@@ -2705,6 +2705,8 @@ def __init__(
27052705
):
27062706
super().__init__()
27072707

2708+
atompair_dist_bins = Tensor(atompair_dist_bins)
2709+
27082710
self.register_buffer('atompair_dist_bins', atompair_dist_bins)
27092711

27102712
num_dist_bins = atompair_dist_bins.shape[-1]
@@ -2828,7 +2830,7 @@ def __init__(
28282830
dim_single = 384,
28292831
dim_pairwise = 128,
28302832
dim_token = 768,
2831-
distance_bins: Float[' dist_bins'] = torch.linspace(3, 20, 38),
2833+
distance_bins: List[float] = torch.linspace(3, 20, 38).tolist(),
28322834
ignore_index = -1,
28332835
num_dist_bins: int | None = None,
28342836
num_plddt_bins = 50,
@@ -3018,6 +3020,8 @@ def __init__(
30183020

30193021
# logit heads
30203022

3023+
distance_bins = Tensor(distance_bins)
3024+
30213025
self.register_buffer('distance_bins', distance_bins)
30223026
num_dist_bins = default(num_dist_bins, len(distance_bins))
30233027

@@ -3065,7 +3069,7 @@ def save(self, path: str | Path, overwrite = False):
30653069
if isinstance(path, str):
30663070
path = Path(path)
30673071

3068-
assert not path.is_dir() and (not path.exists() or overwrite)
3072+
assert path.is_file() and (not path.exists() or overwrite)
30693073

30703074
path.parent.mkdir(exist_ok = True, parents = True)
30713075

@@ -3080,7 +3084,7 @@ def load(self, path: str | Path, strict = False):
30803084
if isinstance(path, str):
30813085
path = Path(path)
30823086

3083-
assert path.exists() and not path.is_dir()
3087+
assert path.exists() and path.is_file()
30843088

30853089
package = torch.load(str(path), map_location = 'cpu')
30863090

@@ -3098,7 +3102,7 @@ def init_and_load(path: str | Path):
30983102
if isinstance(path, str):
30993103
path = Path(path)
31003104

3101-
assert path.exists() and not path.is_dir()
3105+
assert path.is_file()
31023106

31033107
package = torch.load(str(path), map_location = 'cpu')
31043108

alphafold3_pytorch/configs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
from alphafold3_pytorch.typing import typecheck
4+
from alphafold3_pytorch.alphafold3 import Alphafold3
5+
6+
import yaml
7+
from pathlib import Path
8+
9+
from pydantic import BaseModel
10+
11+
# functions
12+
13+
def exists(v):
    """True when `v` carries a value, i.e. is anything other than None."""
    if v is None:
        return False
    return True
15+
16+
@typecheck
def yaml_config_path_to_dict(
    path: str | Path
) -> dict:
    """Read the yaml file at `path` and return its parsed contents as a dict.

    Raises AssertionError when `path` is not an existing file, when the yaml
    cannot be parsed into anything, or when the parsed document is not a
    mapping. Never returns None (the asserts guarantee a dict), so the
    return annotation is `dict` rather than `dict | None`.
    """

    if isinstance(path, str):
        path = Path(path)

    assert path.is_file(), f'no config file found at {str(path)}'

    with open(str(path), 'r') as f:
        maybe_config_dict = yaml.safe_load(f)

    assert exists(maybe_config_dict), f'unable to parse yaml config at {str(path)}'
    # plain string: the original had a stray `f` prefix with no placeholders
    assert isinstance(maybe_config_dict, dict), 'yaml config file is not a dictionary'

    return maybe_config_dict
33+
34+
# base pydantic classes for constructing alphafold3 and trainer from config files
35+
36+
# shared pydantic base for config schemas: permits extra (undeclared) keys in
# the source config instead of rejecting them, and stores raw enum values
# rather than enum members
class BaseModelWithExtra(BaseModel):
    # NOTE(review): pydantic v2 (the pyproject pins >=2.7.2) deprecates the
    # inner `class Config` in favor of `model_config = ConfigDict(...)`; this
    # still works but emits a deprecation warning — consider migrating
    class Config:
        extra = 'allow'
        use_enum_values = True
40+
41+
class Alphafold3Config(BaseModelWithExtra):
    """Pydantic schema mirroring the keyword arguments of `Alphafold3.__init__`,
    so a model can be constructed from a validated yaml config file.
    """

    dim_atom_inputs: int
    dim_template_feats: int
    dim_template_model: int
    atoms_per_window: int
    dim_atom: int
    dim_atompair_inputs: int
    dim_atompair: int
    dim_input_embedder_token: int
    dim_single: int
    dim_pairwise: int
    dim_token: int
    ignore_index: int = -1
    # default to None so configs may omit it, matching the Alphafold3
    # constructor, which falls back to len(distance_bins) when not given
    num_dist_bins: int | None = None
    num_plddt_bins: int
    num_pde_bins: int
    num_pae_bins: int
    sigma_data: int | float
    diffusion_num_augmentations: int
    loss_confidence_weight: int | float
    loss_distogram_weight: int | float
    loss_diffusion_weight: int | float

    @staticmethod
    @typecheck
    def from_yaml_file(path: str | Path):
        """Parse and validate the yaml file at `path` into a config."""
        config_dict = yaml_config_path_to_dict(path)
        return Alphafold3Config(**config_dict)

    def create_instance(self) -> Alphafold3:
        """Instantiate an `Alphafold3` module from this config's fields."""
        # model_dump() is the pydantic v2 replacement for the deprecated .dict()
        alphafold3 = Alphafold3(**self.model_dump())
        return alphafold3
73+
74+
def create_instance_from_yaml_file(path: str | Path) -> Alphafold3:
    """Convenience one-liner: yaml config file at `path` -> `Alphafold3` module."""
    return Alphafold3Config.from_yaml_file(path).create_instance()

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.20"
3+
version = "0.1.21"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -30,13 +30,12 @@ dependencies = [
3030
"ema-pytorch>=0.4.8",
3131
"environs",
3232
"frame-averaging-pytorch>=0.0.18",
33-
"hydra-core",
3433
"jaxtyping>=0.2.28",
3534
"lightning>=2.2.5",
36-
"omegaconf",
3735
"pandas>=1.5.3",
3836
"pdbeccdutils>=0.8.5",
3937
"pydantic>=2.7.2",
38+
"pyyaml",
4039
"taylor-series-linear-attention>=0.1.9",
4140
"timeout_decorator>=0.5.0",
4241
'torch_geometric',

0 commit comments

Comments
 (0)