1111from autointent import Context , Dataset
1212from autointent .configs import InferenceNodeConfig , LoggingConfig , VectorIndexConfig
1313from autointent .custom_types import ListOfGenericLabels , NodeType , SamplerType , ValidationScheme
14- from autointent .metrics import PREDICTION_METRICS
14+ from autointent .metrics import DECISION_METRICS
1515from autointent .nodes import InferenceNode , NodeOptimizer
1616from autointent .nodes .schemes import OptimizationConfig
1717from autointent .utils import load_default_search_space , load_search_space
1818
1919from ._schemas import InferencePipelineOutput , InferencePipelineUtteranceOutput
2020
2121if TYPE_CHECKING :
22- from autointent .modules .abc import DecisionModule , ScoringModule
22+ from autointent .modules .abc import BaseDecision , BaseScorer
2323
2424
2525class Pipeline :
@@ -155,7 +155,7 @@ def fit(
155155 self ._refit (context )
156156
157157 predictions = self .predict (context .data_handler .test_utterances ())
158- for metric_name , metric in PREDICTION_METRICS .items ():
158+ for metric_name , metric in DECISION_METRICS .items ():
159159 context .optimization_info .pipeline_metrics [metric_name ] = metric (
160160 context .data_handler .test_labels (),
161161 predictions ,
@@ -218,8 +218,8 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels:
218218 msg = "Pipeline in optimization mode cannot perform inference"
219219 raise RuntimeError (msg )
220220
221- scoring_module : ScoringModule = self .nodes [NodeType .scoring ].module # type: ignore[assignment,union-attr]
222- decision_module : DecisionModule = self .nodes [NodeType .decision ].module # type: ignore[assignment,union-attr]
221+ scoring_module : BaseScorer = self .nodes [NodeType .scoring ].module # type: ignore[assignment,union-attr]
222+ decision_module : BaseDecision = self .nodes [NodeType .decision ].module # type: ignore[assignment,union-attr]
223223
224224 scores = scoring_module .predict (utterances )
225225 return decision_module .predict (scores )
@@ -235,8 +235,8 @@ def _refit(self, context: Context) -> None:
235235 msg = "Pipeline in optimization mode cannot perform inference"
236236 raise RuntimeError (msg )
237237
238- scoring_module : ScoringModule = self .nodes [NodeType .scoring ].module # type: ignore[assignment,union-attr]
239- decision_module : DecisionModule = self .nodes [NodeType .decision ].module # type: ignore[assignment,union-attr]
238+ scoring_module : BaseScorer = self .nodes [NodeType .scoring ].module # type: ignore[assignment,union-attr]
239+ decision_module : BaseDecision = self .nodes [NodeType .decision ].module # type: ignore[assignment,union-attr]
240240
241241 context .data_handler .prepare_for_refit ()
242242
@@ -258,9 +258,9 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
258258
259259 scores , scores_metadata = self .nodes [NodeType .scoring ].module .predict_with_metadata (utterances ) # type: ignore[union-attr]
260260 predictions = self .nodes [NodeType .decision ].module .predict (scores ) # type: ignore[union-attr,arg-type]
261- regexp_predictions , regexp_predictions_metadata = None , None
262- if NodeType .regexp in self .nodes :
263- regexp_predictions , regexp_predictions_metadata = self .nodes [NodeType .regexp ].module .predict_with_metadata ( # type: ignore[union-attr]
261+ regex_predictions , regex_predictions_metadata = None , None
262+ if NodeType .regex in self .nodes :
263+ regex_predictions , regex_predictions_metadata = self .nodes [NodeType .regex ].module .predict_with_metadata ( # type: ignore[union-attr]
264264 utterances ,
265265 )
266266
@@ -269,9 +269,9 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
269269 output = InferencePipelineUtteranceOutput (
270270 utterance = utterance ,
271271 prediction = predictions [idx ],
272- regexp_prediction = regexp_predictions [idx ] if regexp_predictions is not None else None ,
273- regexp_prediction_metadata = regexp_predictions_metadata [idx ]
274- if regexp_predictions_metadata is not None
272+ regex_prediction = regex_predictions [idx ] if regex_predictions is not None else None ,
273+ regex_prediction_metadata = regex_predictions_metadata [idx ]
274+ if regex_predictions_metadata is not None
275275 else None ,
276276 score = scores [idx ],
277277 score_metadata = scores_metadata [idx ] if scores_metadata is not None else None ,
@@ -280,7 +280,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
280280
281281 return InferencePipelineOutput (
282282 predictions = predictions ,
283- regexp_predictions = regexp_predictions ,
283+ regex_predictions = regex_predictions ,
284284 utterances = outputs ,
285285 )
286286
0 commit comments