11"""Module for regular expressions based intent detection."""
22
33import re
4- from typing import Any , Literal , TypedDict
4+ from typing import Any , TypedDict
55
66from autointent import Context
77from autointent .context .data_handler ._data_handler import RegexPatterns
88from autointent .context .optimization_info import Artifact
99from autointent .custom_types import LabelType
10- from autointent .metrics import REGEXP_METRICS
10+ from autointent .metrics import REGEX_METRICS
1111from autointent .modules .abc import RegexModule
1212from autointent .schemas import Intent
1313
@@ -33,23 +33,19 @@ def from_context(cls, context: Context) -> "Regex":
3333 """Initialize from context."""
3434 return cls ()
3535
36- def get_train_data (self , context : Context ) -> list [Intent ]:
37- return context .data_handler .dataset .intents
38-
39- def fit (self , intents : list [dict [str , Any ]]) -> None :
36+ def fit (self , intents : list [Intent ]) -> None :
4037 """
4138 Fit the model.
4239
4340 :param intents: Intents to fit
4441 """
45- intents_parsed = [Intent (** dct ) for dct in intents ]
4642 self .regex_patterns = [
4743 RegexPatterns (
4844 id = intent .id ,
4945 regex_full_match = intent .regex_full_match ,
5046 regex_partial_match = intent .regex_partial_match ,
5147 )
52- for intent in intents_parsed
48+ for intent in intents
5349 ]
5450 self ._compile_regex_patterns ()
5551
@@ -109,24 +105,32 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str]]]:
             matches["partial_matches"].extend(intent_matches["partial_matches"])
         return list(prediction), matches
 
-    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
+    def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
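+        """
+        Evaluate the module on the hold-out validation split.
+
+        :param context: Context containing the validation data.
+        :param metrics: Names of the metrics to compute.
+        :return: Mapping of metric names to their computed values.
+        """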
+        self.fit(context.data_handler.dataset.intents)
+
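+        # NOTE: split index 0 is assumed to refer to the hold-out validation split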
+        val_utterances = context.data_handler.validation_utterances(0)
+        val_labels = context.data_handler.validation_labels(0)
+
+        pred_labels = self.predict(val_utterances)
+
+        chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics}
+        return self.score_metrics_ho((val_labels, pred_labels), chosen_metrics)
+
+    def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
113120 """
114- Calculate metric on test set and return metric value .
121+ Evaluate the scorer on a test set and compute the specified metric .
115122
116- :param context: Context to score
117- :param split: Split to score on
123+ :param context: Context containing test set and other data.
124+ :param split: Target split
118125 :return: Computed metrics value for the test set or error code of metrics
119126 """
-        # TODO add parameter to a whole pipeline (or just to regex module):
-        # whether or not to omit utterances on next stages if they were detected with regex module
-        assets = {
-            "test_matches": list(self.predict(context.data_handler.test_utterances())),
-        }
-        if assets["test_matches"] is None:
-            msg = "no matches found"
-            raise ValueError(msg)
-        chosen_metrics = {name: fn for name, fn in REGEXP_METRICS.items() if name in metrics}
-        return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), chosen_metrics)
+        chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics}
+
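+        # compute the chosen metrics across the cross-validation folds; the second
+        # value returned by score_metrics_cv is not needed here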
+        metrics_calculated, _ = self.score_metrics_cv(
+            chosen_metrics, context.data_handler.validation_iterator()
+        )
+
+        return metrics_calculated
 
     def clear_cache(self) -> None:
         """Clear cache."""