11"""Module for regular expressions based intent detection."""
22
33import re
4+ from collections .abc import Iterable
45from typing import Any , TypedDict
56
7+ import numpy as np
8+ import numpy .typing as npt
9+
610from autointent import Context
711from autointent .context .data_handler ._data_handler import RegexPatterns
812from autointent .context .optimization_info import Artifact
9- from autointent .custom_types import LabelType
13+ from autointent .custom_types import LabelType , ListOfGenericLabels , ListOfLabels
1014from autointent .metrics import REGEX_METRICS
1115from autointent .modules .base import BaseRegex
1216from autointent .schemas import Intent
@@ -36,7 +40,10 @@ class Regex(BaseRegex):
3640 name: Name of the module, defaults to "regex"
3741 """
3842
39- name = "regex"
43+ name = "simple"
44+ supports_multiclass = True
45+ supports_multilabel = True
46+ supports_oos = False
4047
4148 @classmethod
4249 def from_context (cls , context : Context ) -> "Regex" :
@@ -158,7 +165,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
158165 return self .score_metrics_ho ((val_labels , pred_labels ), chosen_metrics )
159166
160167 def score_cv (self , context : Context , metrics : list [str ]) -> dict [str , float ]:
161- """Score the model using cross-validation.
168+ """Score the model in cross-validation mode .
162169
163170 Args:
164171 context: Context containing validation data
@@ -169,10 +176,42 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
169176 """
170177 chosen_metrics = {name : fn for name , fn in REGEX_METRICS .items () if name in metrics }
171178
172- metrics_calculated , _ = self .score_metrics_cv (chosen_metrics , context .data_handler .validation_iterator ())
179+ metrics_calculated , _ = self .score_metrics_cv (
180+ chosen_metrics , context .data_handler .validation_iterator (), intents = context .data_handler .dataset .intents
181+ )
173182
174183 return metrics_calculated
175184
def score_metrics_cv(
    self,
    metrics_dict: dict[str, Any],
    cv_iterator: Iterable[tuple[list[str], ListOfLabels, list[str], ListOfLabels]],
    intents: list[Intent],
) -> tuple[dict[str, float], list[ListOfGenericLabels] | list[npt.NDArray[Any]]]:
    """Evaluate the module fold by fold and average every metric.

    The module is fitted a single time on ``intents`` before any fold is
    consumed. The training part of each fold is ignored; only the validation
    utterances are predicted and compared against the validation labels.

    Args:
        metrics_dict: Mapping from metric name to a callable invoked as
            ``fn(val_labels, val_preds)``
        cv_iterator: Iterable yielding 4-tuples whose last two items are the
            validation utterances and validation labels of a fold
        intents: Intents passed to ``fit``

    Returns:
        Tuple of (metric name -> mean score over all folds, list holding the
        predictions of each fold in iteration order)
    """
    per_fold_scores: dict[str, list[float]] = {metric_name: [] for metric_name in metrics_dict}
    fold_predictions = []

    # One fit up front: the same fitted state is reused for every fold.
    self.fit(intents)

    for _train_utterances, _train_labels, fold_utterances, fold_labels in cv_iterator:
        predicted = self.predict(fold_utterances)
        fold_predictions.append(predicted)
        # per_fold_scores was built from metrics_dict, so iteration order matches it.
        for metric_name, scores in per_fold_scores.items():
            scores.append(metrics_dict[metric_name](fold_labels, predicted))

    averaged: dict[str, float] = {}
    for metric_name, scores in per_fold_scores.items():
        averaged[metric_name] = float(np.mean(scores))

    return averaged, fold_predictions  # type: ignore[return-value]
214+
def clear_cache(self) -> None:
    """Clear cached regex patterns.

    Removes the ``regex_patterns`` attribute from the instance. The call is
    idempotent: if the attribute is absent (never set, or already cleared)
    the method returns silently instead of raising ``AttributeError``.
    """
    try:
        del self.regex_patterns
    except AttributeError:
        # Nothing is cached, so there is nothing to release.
        pass
0 commit comments