11from __future__ import annotations
22
33from alphafold3_pytorch .typing import typecheck
4+ from typing import Callable , List
5+
46from alphafold3_pytorch .alphafold3 import Alphafold3
57
8+ from alphafold3_pytorch .trainer import (
9+ Trainer ,
10+ Dataset ,
11+ Fabric ,
12+ Optimizer ,
13+ LRScheduler
14+ )
15+
616import yaml
717from pathlib import Path
818
@@ -27,7 +37,7 @@ def yaml_config_path_to_dict(
2737 maybe_config_dict = yaml .safe_load (f )
2838
2939 assert exists (maybe_config_dict ), f'unable to parse yaml config at { str (path )} '
30- assert isinstance (maybe_config_dict , dict ), f 'yaml config file is not a dictionary'
40+ assert isinstance (maybe_config_dict , dict ), 'yaml config file is not a dictionary'
3141
3242 return maybe_config_dict
3343
@@ -76,4 +86,59 @@ def create_instance_from_yaml_file(path: str | Path) -> Alphafold3:
7686 return af3_config .create_instance ()
7787
class TrainerConfig(BaseModelWithExtra):
    """Declarative settings from which a `Trainer` is constructed.

    `create_instance` dumps every field with `model_dump()` and forwards the
    result to `Trainer` as keyword arguments. The `model` entry is the one
    exception: it is replaced by the object returned from
    `self.model.create_instance()`.
    """

    # NOTE(review): the meaning of each field is defined by the `Trainer`
    # constructor, which is not visible here -- confirm against
    # `alphafold3_pytorch.trainer`.
    model: Alphafold3Config
    num_train_steps: int
    batch_size: int
    grad_accum_every: int
    valid_every: int
    ema_decay: float
    lr: float
    clip_grad_norm: int | float
    accelerator: str
    checkpoint_prefix: str
    checkpoint_every: int
    checkpoint_folder: str
    overwrite_checkpoints: bool

    @staticmethod
    @typecheck
    def from_yaml_file(path: str | Path):
        """Build a `TrainerConfig` from the YAML file at `path`.

        The file is parsed by `yaml_config_path_to_dict` and its top-level
        keys are passed to the `TrainerConfig` constructor as keyword
        arguments.
        """
        return TrainerConfig(**yaml_config_path_to_dict(path))

    def create_instance(
        self,
        dataset: Dataset,
        fabric: Fabric | None = None,
        test_dataset: Dataset | None = None,
        optimizer: Optimizer | None = None,
        scheduler: LRScheduler | None = None,
        valid_dataset: Dataset | None = None,
        map_dataset_input_fn: Callable | None = None,
    ) -> Trainer:
        """Instantiate a `Trainer` from this config plus runtime objects.

        The arguments are objects that are supplied by the caller rather than
        read from the config; each is forwarded to `Trainer` under the same
        keyword name.

        Returns:
            The `Trainer` built from the merged keyword arguments.
        """
        # Entries are evaluated left to right: the dumped config comes first,
        # then the explicit keys below override any dumped key of the same
        # name (notably `model`, which the dump holds in serialized form).
        merged_kwargs = {
            **self.model_dump(),
            'model': self.model.create_instance(),
            'dataset': dataset,
            'fabric': fabric,
            'test_dataset': test_dataset,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'valid_dataset': valid_dataset,
            'map_dataset_input_fn': map_dataset_input_fn,
        }

        return Trainer(**merged_kwargs)
137+
def create_instance_from_yaml_file(
    path: str | Path,
    **kwargs
) -> Trainer:
    """Load a `TrainerConfig` from the YAML file at `path` and build a `Trainer`.

    All extra keyword arguments are forwarded unchanged to
    `TrainerConfig.create_instance`.
    """
    # NOTE(review): a function with this same name (returning `Alphafold3`)
    # appears earlier in the file. If both live at module level this
    # definition shadows it -- confirm the intended scope of each.
    return TrainerConfig.from_yaml_file(path).create_instance(**kwargs)
0 commit comments