3131from ._schemas import InferencePipelineOutput , InferencePipelineUtteranceOutput
3232
3333if TYPE_CHECKING :
34- from autointent .modules .base import BaseDecision , BaseScorer
34+ from autointent .modules .base import BaseDecision , BaseRegex , BaseScorer
3535
3636
3737class Pipeline :
@@ -41,7 +41,7 @@ def __init__(
4141 self ,
4242 nodes : list [NodeOptimizer ] | list [InferenceNode ],
4343 sampler : SamplerType = "brute" ,
44- seed : int = 42 ,
44+ seed : int | None = 42 ,
4545 ) -> None :
4646 """Initialize the pipeline optimizer.
4747
@@ -85,7 +85,7 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
8585 assert_never (config )
8686
8787 @classmethod
88- def from_search_space (cls , search_space : list [dict [str , Any ]] | Path | str , seed : int = 42 ) -> "Pipeline" :
88+ def from_search_space (cls , search_space : list [dict [str , Any ]] | Path | str , seed : int | None = 42 ) -> "Pipeline" :
8989 """Search space to pipeline optimizer.
9090
9191 Args:
@@ -101,7 +101,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
101101 return cls (nodes = nodes , seed = seed )
102102
103103 @classmethod
104- def from_preset (cls , name : SearchSpacePresets , seed : int = 42 ) -> "Pipeline" :
104+ def from_preset (cls , name : SearchSpacePresets , seed : int | None = 42 ) -> "Pipeline" :
105105 optimization_config = load_preset (name )
106106 config = OptimizationConfig (seed = seed , ** optimization_config )
107107 return cls .from_optimization_config (config = config )
@@ -186,7 +186,7 @@ def fit(
186186 msg = "Pipeline in inference mode cannot be fitted"
187187 raise RuntimeError (msg )
188188
189- context = Context ()
189+ context = Context (self . seed )
190190 context .set_dataset (dataset , self .data_config )
191191 context .configure_logging (self .logging_config )
192192 context .configure_transformer (self .embedder_config )
@@ -199,25 +199,43 @@ def fit(
199199 self ._logger .warning (
200200 "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
201201 )
202+ elif context .logging_config .clear_ram and not context .logging_config .dump_modules :
203+ self ._logger .warning (
204+ "Test data is provided, but final metrics won't be calculated "
205+ "because fitted modules won't be saved neither in RAM nor in file system."
206+ "Change settings in LoggerConfig to obtain different behavior."
207+ )
202208
203209 if sampler is None :
204210 sampler = self .sampler
205211
206212 self ._fit (context , sampler )
207213
208- if context .is_ram_to_clear () :
214+ if context .logging_config . clear_ram and context . logging_config . dump_modules :
209215 nodes_configs = context .optimization_info .get_inference_nodes_config ()
210216 nodes_list = [InferenceNode .from_config (cfg ) for cfg in nodes_configs ]
211- else :
217+ elif not context . logging_config . clear_ram :
212218 modules_dict = context .optimization_info .get_best_modules ()
213219 nodes_list = [InferenceNode (module , node_type ) for node_type , module in modules_dict .items ()]
220+ else :
221+ self ._logger .info (
222+ "Skipping calculating final metrics because fitted modules weren't saved."
223+ "Change settings in LoggerConfig to obtain different behavior."
224+ )
225+ return context
214226
215- self .nodes = {node .node_type : node for node in nodes_list }
227+ self .nodes = {node .node_type : node for node in nodes_list if node . node_type != NodeType . embedding }
216228
217229 if refit_after :
218- # TODO reflect this refitting in dumped version of pipeline
219230 self ._refit (context )
220231
232+ self ._nodes_configs : dict [str , InferenceNodeConfig ] = {
233+ NodeType (cfg .node_type ): cfg
234+ for cfg in context .optimization_info .get_inference_nodes_config ()
235+ if cfg .node_type != NodeType .embedding
236+ }
237+ self ._dump_dir = context .logging_config .dirpath
238+
221239 if test_utterances is not None :
222240 predictions = self .predict (test_utterances )
223241 for metric_name , metric in DECISION_METRICS .items ():
@@ -229,6 +247,41 @@ def fit(
229247
230248 return context
231249
250+ def dump (self , path : str | Path | None = None ) -> None :
251+ if isinstance (path , str ):
252+ path = Path (path )
253+ elif path is None :
254+ if hasattr (self , "_dump_dir" ):
255+ path = self ._dump_dir
256+ else :
257+ msg = (
258+ "Either you didn't trained the pipeline yet or fitted modules weren't saved during optimization. "
259+ "Change settings in LoggerConfig and retrain the pipeline to obtain different behavior."
260+ )
261+ self ._logger .error (msg )
262+ raise RuntimeError (msg )
263+
264+ scoring_module : BaseScorer = self .nodes [NodeType .scoring ].module # type: ignore[assignment,union-attr]
265+ decision_module : BaseDecision = self .nodes [NodeType .decision ].module # type: ignore[assignment,union-attr]
266+
267+ scoring_dump_dir = str (path / "scoring_module" )
268+ decision_dump_dir = str (path / "decision_module" )
269+ scoring_module .dump (scoring_dump_dir )
270+ decision_module .dump (decision_dump_dir )
271+
272+ self ._nodes_configs [NodeType .scoring ].load_path = scoring_dump_dir
273+ self ._nodes_configs [NodeType .decision ].load_path = decision_dump_dir
274+
275+ if NodeType .regex in self .nodes :
276+ regex_module : BaseRegex = self .nodes [NodeType .regex ].module # type: ignore[assignment,union-attr]
277+ regex_dump_dir = str (path / "regex_module" )
278+ regex_module .dump (regex_dump_dir )
279+ self ._nodes_configs [NodeType .regex ].load_path = regex_dump_dir
280+
281+ inference_nodes_configs = [cfg .asdict () for cfg in self ._nodes_configs .values ()]
282+ with (path / "inference_config.yaml" ).open ("w" ) as file :
283+ yaml .dump (inference_nodes_configs , file )
284+
232285 def validate_modules (self , dataset : Dataset , mode : SearchSpaceValidationMode ) -> None :
233286 """Validate modules with dataset.
234287
@@ -240,18 +293,6 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) ->
240293 if isinstance (node , NodeOptimizer ):
241294 node .validate_nodes_with_dataset (dataset , mode )
242295
243- @classmethod
244- def from_dict_config (cls , nodes_configs : list [dict [str , Any ]]) -> "Pipeline" :
245- """Create inference pipeline from dictionary config.
246-
247- Args:
248- nodes_configs: list of config for nodes
249-
250- Returns:
251- Inference pipeline
252- """
253- return cls .from_config ([InferenceNodeConfig (** cfg ) for cfg in nodes_configs ])
254-
255296 @classmethod
256297 def from_config (cls , nodes_configs : list [InferenceNodeConfig ]) -> "Pipeline" :
257298 """Create inference pipeline from config.
@@ -283,13 +324,13 @@ def load(
283324 Inference pipeline
284325 """
285326 with (Path (path ) / "inference_config.yaml" ).open () as file :
286- inference_dict_config : dict [str , Any ] = yaml .safe_load (file )
327+ inference_nodes_configs : list [ dict [str , Any ] ] = yaml .safe_load (file )
287328
288329 inference_config = [
289330 InferenceNodeConfig (
290331 ** node_config , embedder_config = embedder_config , cross_encoder_config = cross_encoder_config
291332 )
292- for node_config in inference_dict_config [ "nodes_configs" ]
333+ for node_config in inference_nodes_configs
293334 ]
294335
295336 return cls .from_config (inference_config )
0 commit comments