33import json
44import logging
55from pathlib import Path
6- from typing import TYPE_CHECKING , Any
6+ from typing import TYPE_CHECKING , Any , get_args
77
88import numpy as np
99import yaml
1313from autointent .custom_types import ListOfGenericLabels , NodeType , SamplerType
1414from autointent .metrics import DECISION_METRICS
1515from autointent .nodes import InferenceNode , NodeOptimizer
16- from autointent .nodes .schemes import OptimizationConfig
16+ from autointent .nodes .schemes import OptimizationConfig , OptimizationSearchSpaceConfig
1717from autointent .utils import load_default_search_space , load_search_space
1818
1919from ._schemas import InferencePipelineOutput , InferencePipelineUtteranceOutput
@@ -28,17 +28,24 @@ class Pipeline:
2828 def __init__ (
2929 self ,
3030 nodes : list [NodeOptimizer ] | list [InferenceNode ],
31+ sampler : SamplerType = "brute" ,
3132 seed : int = 42 ,
3233 ) -> None :
3334 """
3435 Initialize the pipeline optimizer.
3536
3637 :param nodes: list of nodes
38+ :param sampler: sampler type
3739 :param seed: random seed
3840 """
3941 self ._logger = logging .getLogger (__name__ )
4042 self .nodes = {node .node_type : node for node in nodes }
4143 self .seed = seed
44+ if sampler not in get_args (SamplerType ):
45+ msg = f"Sampler should be one of { get_args (SamplerType )} "
46+ raise ValueError (msg )
47+
48+ self .sampler = sampler
4249
4350 if isinstance (nodes [0 ], NodeOptimizer ):
4451 self .logging_config = LoggingConfig (dump_dir = None )
@@ -74,10 +81,34 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
7481 """
7582 if isinstance (search_space , Path | str ):
7683 search_space = load_search_space (search_space )
77- validated_search_space = OptimizationConfig (search_space ).model_dump () # type: ignore[arg-type]
84+ validated_search_space = OptimizationSearchSpaceConfig (search_space ).model_dump () # type: ignore[arg-type]
7885 nodes = [NodeOptimizer (** node ) for node in validated_search_space ]
7986 return cls (nodes = nodes , seed = seed )
8087
88+ @classmethod
89+ def from_optimization_config (cls , config : dict [str , Any ] | Path | str ) -> "Pipeline" :
90+ """
91+ Create pipeline optimizer from optimization config.
92+
93+ :param config: Optimization config
94+ :return:
95+ """
96+ if isinstance (config , Path | str ):
97+ with Path (config ).open () as file :
98+ loaded_config = yaml .safe_load (file )
99+ else :
100+ loaded_config = config
101+ optimization_config = OptimizationConfig (** loaded_config )
102+ pipeline = cls (
103+ [NodeOptimizer (** node .model_dump ()) for node in optimization_config .task_config .search_space ],
104+ optimization_config .task_config .sampler ,
105+ optimization_config .seed ,
106+ )
107+ pipeline .set_config (optimization_config .logging_config )
108+ pipeline .set_config (optimization_config .vector_index_config )
109+ pipeline .set_config (optimization_config .data_config )
110+ return pipeline
111+
81112 @classmethod
82113 def default_optimizer (cls , multilabel : bool , seed : int = 42 ) -> "Pipeline" :
83114 """
@@ -90,7 +121,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
90121 """
91122 return cls .from_search_space (search_space = load_default_search_space (multilabel ), seed = seed )
92123
93- def _fit (self , context : Context , sampler : SamplerType = "brute" ) -> None :
124+ def _fit (self , context : Context , sampler : SamplerType ) -> None :
94125 """
95126 Optimize the pipeline.
96127
@@ -99,7 +130,7 @@ def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
99130 self .context = context
100131 self ._logger .info ("starting pipeline optimization..." )
101132 self .context .callback_handler .start_run (
102- run_name = self .context .logging_config .run_name ,
133+ run_name = self .context .logging_config .get_run_name () ,
103134 dirpath = self .context .logging_config .dirpath ,
104135 )
105136 for node_type in NodeType :
@@ -123,7 +154,7 @@ def fit(
123154 self ,
124155 dataset : Dataset ,
125156 refit_after : bool = False ,
126- sampler : SamplerType = "brute" ,
157+ sampler : SamplerType | None = None ,
127158 ) -> Context :
128159 """
129160 Optimize the pipeline from dataset.
@@ -148,6 +179,9 @@ def fit(
148179 "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
149180 )
150181
182+ if sampler is None :
183+ sampler = self .sampler or "brute"
184+
151185 self ._fit (context , sampler )
152186
153187 if context .is_ram_to_clear ():
0 commit comments