33import json
44import logging
55from pathlib import Path
6- from typing import TYPE_CHECKING , Any , get_args
6+ from typing import TYPE_CHECKING , Any
77
88import numpy as np
99import yaml
1515 DataConfig ,
1616 EmbedderConfig ,
1717 HFModelConfig ,
18+ HPOConfig ,
1819 InferenceNodeConfig ,
1920 LoggingConfig ,
2021)
2122from autointent .custom_types import (
2223 ListOfGenericLabels ,
2324 NodeType ,
24- SamplerType ,
2525 SearchSpacePreset ,
2626 SearchSpaceValidationMode ,
2727)
@@ -44,7 +44,6 @@ class Pipeline:
4444 def __init__ (
4545 self ,
4646 nodes : list [NodeOptimizer ] | list [InferenceNode ],
47- sampler : SamplerType = "brute" ,
4847 seed : int | None = 42 ,
4948 ) -> None :
5049 """Initialize the pipeline optimizer.
@@ -57,23 +56,19 @@ def __init__(
5756 self ._logger = logging .getLogger (__name__ )
5857 self .nodes = {node .node_type : node for node in nodes }
5958 self ._seed = seed
60- if sampler not in get_args (SamplerType ):
61- msg = f"Sampler should be one of { get_args (SamplerType )} "
62- raise ValueError (msg )
63-
64- self ._sampler = sampler
6559
6660 if isinstance (nodes [0 ], NodeOptimizer ):
6761 self .logging_config = LoggingConfig ()
6862 self .embedder_config = EmbedderConfig ()
6963 self .cross_encoder_config = CrossEncoderConfig ()
7064 self .data_config = DataConfig ()
7165 self .transformer_config = HFModelConfig ()
66+ self .hpo_config = HPOConfig ()
7267 elif not isinstance (nodes [0 ], InferenceNode ):
7368 assert_never (nodes )
7469
7570 def set_config (
76- self , config : LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
71+ self , config : LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig | HPOConfig
7772 ) -> None :
7873 """Set the configuration for the pipeline.
7974
@@ -90,6 +85,8 @@ def set_config(
9085 self .data_config = config
9186 elif isinstance (config , HFModelConfig ):
9287 self .transformer_config = config
88+ elif isinstance (config , HPOConfig ):
89+ self .hpo_config = config
9390 else :
9491 assert_never (config )
9592
@@ -126,23 +123,23 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
126123 if isinstance (config , dict ):
127124 dict_params = config
128125 else :
129- with Path (config ).open () as file :
126+ with Path (config ).open (encoding = "utf-8" ) as file :
130127 dict_params = yaml .safe_load (file )
131128 optimization_config = OptimizationConfig (** dict_params )
132129
133130 pipeline = cls (
134131 [NodeOptimizer (** node ) for node in optimization_config .search_space ],
135- optimization_config .sampler ,
136132 optimization_config .seed ,
137133 )
138134 pipeline .set_config (optimization_config .logging_config )
139135 pipeline .set_config (optimization_config .data_config )
140136 pipeline .set_config (optimization_config .embedder_config )
141137 pipeline .set_config (optimization_config .cross_encoder_config )
142138 pipeline .set_config (optimization_config .transformer_config )
139+ pipeline .set_config (optimization_config .hpo_config )
143140 return pipeline
144141
145- def _fit (self , context : Context , sampler : SamplerType ) -> None :
142+ def _fit (self , context : Context ) -> None :
146143 """Optimize the pipeline.
147144
148145 Args:
@@ -167,7 +164,7 @@ def _fit(self, context: Context, sampler: SamplerType) -> None:
167164 for node_type in NodeType :
168165 node_optimizer = self .nodes .get (node_type , None )
169166 if node_optimizer is not None :
170- node_optimizer .fit (context , sampler ) # type: ignore[union-attr]
167+ node_optimizer .fit (context ) # type: ignore[union-attr]
171168 self .context .callback_handler .end_run ()
172169
173170 def _is_inference (self ) -> bool :
@@ -182,7 +179,6 @@ def fit(
182179 self ,
183180 dataset : Dataset ,
184181 refit_after : bool = False ,
185- sampler : SamplerType | None = None ,
186182 incompatible_search_space : SearchSpaceValidationMode = "filter" ,
187183 ) -> Context :
188184 """Optimize the pipeline from dataset.
@@ -206,6 +202,7 @@ def fit(
206202 context .configure_transformer (self .embedder_config )
207203 context .configure_transformer (self .cross_encoder_config )
208204 context .configure_transformer (self .transformer_config )
205+ context .configure_hpo (self .hpo_config )
209206
210207 self .validate_modules (dataset , mode = incompatible_search_space )
211208
@@ -221,10 +218,7 @@ def fit(
221218 "Change settings in LoggerConfig to obtain different behavior."
222219 )
223220
224- if sampler is None :
225- sampler = self ._sampler
226-
227- self ._fit (context , sampler )
221+ self ._fit (context )
228222
229223 if context .logging_config .clear_ram and context .logging_config .dump_modules :
230224 nodes_configs = context .optimization_info .get_inference_nodes_config ()
@@ -336,7 +330,7 @@ def load(
336330 embedder_config: one can override presaved settings
337331 cross_encoder_config: one can override presaved settings
338332 """
339- with (Path (path ) / "inference_config.yaml" ).open () as file :
333+ with (Path (path ) / "inference_config.yaml" ).open (encoding = "utf-8" ) as file :
340334 inference_nodes_configs : list [dict [str , Any ]] = yaml .safe_load (file )
341335
342336 inference_config = [
0 commit comments