@@ -74,7 +74,6 @@ def log_module_optimization(
7474 module_params : dict [str , Any ],
7575 metric_value : float ,
7676 metric_name : str ,
77- artifact : Artifact ,
7877 module_dump_dir : str | None ,
7978 module : "Module | None" = None ,
8079 ) -> None :
@@ -103,13 +102,11 @@ def log_module_optimization(
103102 if module :
104103 self .modules .add_module (node_type , module )
105104
106- self .artifacts .add_artifact (node_type , artifact )
107-
108105 def _get_metrics_values (self , node_type : str ) -> list [float ]:
109106 """Retrieve all metric values for a specific node type."""
110107 return [trial .metric_value for trial in self .trials .get_trials (node_type )]
111108
112- def _get_best_trial_idx (self , node_type : str ) -> int | None :
109+ def get_best_trial_idx (self , node_type : str ) -> int | None :
113110 """
114111 Retrieve the index of the best trial for a node type.
115112
@@ -133,7 +130,7 @@ def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifa
133130 :return: The best artifact for the node type.
134131 :raises ValueError: If no best trial exists for the node type.
135132 """
136- best_idx = self ._get_best_trial_idx (node_type )
133+ best_idx = self .get_best_trial_idx (node_type )
137134 if best_idx is None :
138135 msg = f"No best trial for { node_type } "
139136 raise ValueError (msg )
@@ -194,7 +191,7 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode
194191
195192 :return: List of `InferenceNodeConfig` objects for inference nodes.
196193 """
197- trial_ids = [self ._get_best_trial_idx (node_type ) for node_type in NodeType ]
194+ trial_ids = [self .get_best_trial_idx (node_type ) for node_type in NodeType ]
198195 res = []
199196 for idx , node_type in zip (trial_ids , NodeType , strict = True ):
200197 if idx is None :
@@ -216,7 +213,7 @@ def _get_best_module(self, node_type: str) -> "Module | None":
216213 :param node_type: Type of the node.
217214 :return: The best module, or None if no best trial exists.
218215 """
219- idx = self ._get_best_trial_idx (node_type )
216+ idx = self .get_best_trial_idx (node_type )
220217 if idx is not None :
221218 return self .modules .get (node_type )[idx ]
222219 return None
0 commit comments