1+ import json
12import re
2- from typing import TypedDict
3+ from pathlib import Path
4+ from typing import Any , TypedDict
35
46from typing_extensions import Self
57
68from autointent import Context
79from autointent .context .data_handler .data_handler import RegexPatterns
10+ from autointent .context .data_handler .schemas import Intent
811from autointent .context .optimization_info .data_models import Artifact
912from autointent .custom_types import LabelType
1013from autointent .metrics .regexp import RegexpMetricFn
@@ -19,43 +22,60 @@ class RegexPatternsCompiled(TypedDict):
1922
2023
2124class RegExp (Module ):
22- name = "regexp"
23-
24- def __init__ (self , regexp_patterns : list [RegexPatterns ]) -> None :
25- self .regexp_patterns = regexp_patterns
26-
@classmethod
def from_context(cls, context: Context) -> Self:
    """Create a module instance without reading anything from ``context``.

    :param context: accepted for interface compatibility; it is not used here.
    :return: a new instance constructed with no arguments.
    """
    module = cls()
    return module
28+
29+ def fit (self , intents : list [dict [str , Any ]]) -> None :
30+ intents_parsed = [Intent (** dct ) for dct in intents ]
31+ self .regexp_patterns = [
32+ RegexPatterns (
33+ id = intent .id ,
34+ regexp_full_match = intent .regexp_full_match ,
35+ regexp_partial_match = intent .regexp_partial_match ,
36+ )
37+ for intent in intents_parsed
4138 ]
39+ self ._compile_regex_patterns ()
4240
4341 def predict (self , utterances : list [str ]) -> list [LabelType ]:
44- return [list (self ._predict_single (ut )) for ut in utterances ]
45-
46- def _match (self , text : str , intent_record : RegexPatternsCompiled ) -> bool :
47- full_match = any (ptn .fullmatch (text ) for ptn in intent_record ["regexp_full_match" ])
48- if full_match :
49- return True
50- return any (ptn .match (text ) for ptn in intent_record ["regexp_partial_match" ])
42+ return [self ._predict_single (utterance )[0 ] for utterance in utterances ]
43+
44+ def predict_with_metadata (
45+ self ,
46+ utterances : list [str ],
47+ ) -> tuple [list [LabelType ], list [dict [str , Any ]] | None ]:
48+ predictions , metadata = [], []
49+ for utterance in utterances :
50+ prediction , matches = self ._predict_single (utterance )
51+ predictions .append (prediction )
52+ metadata .append (matches )
53+ return predictions , metadata
54+
55+ def _match (self , utterance : str , intent_record : RegexPatternsCompiled ) -> dict [str , list [str ]]:
56+ full_matches = [
57+ pattern .pattern
58+ for pattern in intent_record ["regexp_full_match" ]
59+ if pattern .fullmatch (utterance ) is not None
60+ ]
61+ partial_matches = [
62+ pattern .pattern
63+ for pattern in intent_record ["regexp_partial_match" ]
64+ if pattern .search (utterance ) is not None
65+ ]
66+ return {"full_matches" : full_matches , "partial_matches" : partial_matches }
5167
52- def _predict_single (self , utterance : str ) -> set [ int ]:
68+ def _predict_single (self , utterance : str ) -> tuple [ LabelType , dict [ str , list [ str ]] ]:
5369 # todo test this
54- return {
55- intent_record ["id" ]
56- for intent_record in self .regexp_patterns_compiled
57- if self ._match (utterance , intent_record )
58- }
70+ prediction = set ()
71+ matches : dict [str , list [str ]] = {"full_matches" : [], "partial_matches" : []}
72+ for intent_record in self .regexp_patterns_compiled :
73+ intent_matches = self ._match (utterance , intent_record )
74+ if intent_matches ["full_matches" ] or intent_matches ["partial_matches" ]:
75+ prediction .add (intent_record ["id" ])
76+ matches ["full_matches" ].extend (intent_matches ["full_matches" ])
77+ matches ["partial_matches" ].extend (intent_matches ["partial_matches" ])
78+ return list (prediction ), matches
5979
6080 def score (self , context : Context , metric_fn : RegexpMetricFn ) -> float :
6181 # TODO add parameter to a whole pipeline (or just to regexp module):
@@ -78,7 +98,29 @@ def get_assets(self) -> Artifact:
7898 return Artifact ()
7999
80100 def load (self , path : str ) -> None :
81- pass
101+ dump_dir = Path (path )
102+
103+ with (dump_dir / self .metadata_dict_name ).open () as file :
104+ self .regexp_patterns = json .load (file )
105+
106+ self ._compile_regex_patterns ()
82107
83108 def dump (self , path : str ) -> None :
84- pass
109+ dump_dir = Path (path )
110+
111+ with (dump_dir / self .metadata_dict_name ).open ("w" ) as file :
112+ json .dump (self .regexp_patterns , file , indent = 4 )
113+
114+ def _compile_regex_patterns (self ) -> None :
115+ self .regexp_patterns_compiled = [
116+ RegexPatternsCompiled (
117+ id = regexp_patterns ["id" ],
118+ regexp_full_match = [
119+ re .compile (pattern , flags = re .IGNORECASE ) for pattern in regexp_patterns ["regexp_full_match" ]
120+ ],
121+ regexp_partial_match = [
122+ re .compile (ptn , flags = re .IGNORECASE ) for ptn in regexp_patterns ["regexp_partial_match" ]
123+ ],
124+ )
125+ for regexp_patterns in self .regexp_patterns
126+ ]
0 commit comments