Skip to content

Commit 8eab2b7

Browse files
voorhs and Samoed authored
Refactor/return regexp support (#134)
* return regex module
* fix typing
* return test for regex
* rename modules
* make naming consistent: retriever -> embedding; prediction -> decision
* fix codestyle
* update schema

---------

Co-authored-by: Roman Solomatin <[email protected]>
1 parent abe8c2f commit 8eab2b7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+562
-928
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
from autointent import Context, Dataset
1212
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
1313
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
14-
from autointent.metrics import PREDICTION_METRICS
14+
from autointent.metrics import DECISION_METRICS
1515
from autointent.nodes import InferenceNode, NodeOptimizer
1616
from autointent.nodes.schemes import OptimizationConfig
1717
from autointent.utils import load_default_search_space, load_search_space
1818

1919
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
2020

2121
if TYPE_CHECKING:
22-
from autointent.modules.abc import DecisionModule, ScoringModule
22+
from autointent.modules.abc import BaseDecision, BaseScorer
2323

2424

2525
class Pipeline:
@@ -155,7 +155,7 @@ def fit(
155155
self._refit(context)
156156

157157
predictions = self.predict(context.data_handler.test_utterances())
158-
for metric_name, metric in PREDICTION_METRICS.items():
158+
for metric_name, metric in DECISION_METRICS.items():
159159
context.optimization_info.pipeline_metrics[metric_name] = metric(
160160
context.data_handler.test_labels(),
161161
predictions,
@@ -218,8 +218,8 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels:
218218
msg = "Pipeline in optimization mode cannot perform inference"
219219
raise RuntimeError(msg)
220220

221-
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
222-
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
221+
scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
222+
decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
223223

224224
scores = scoring_module.predict(utterances)
225225
return decision_module.predict(scores)
@@ -235,8 +235,8 @@ def _refit(self, context: Context) -> None:
235235
msg = "Pipeline in optimization mode cannot perform inference"
236236
raise RuntimeError(msg)
237237

238-
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
239-
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
238+
scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
239+
decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
240240

241241
context.data_handler.prepare_for_refit()
242242

@@ -258,9 +258,9 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
258258

259259
scores, scores_metadata = self.nodes[NodeType.scoring].module.predict_with_metadata(utterances) # type: ignore[union-attr]
260260
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr,arg-type]
261-
regexp_predictions, regexp_predictions_metadata = None, None
262-
if NodeType.regexp in self.nodes:
263-
regexp_predictions, regexp_predictions_metadata = self.nodes[NodeType.regexp].module.predict_with_metadata( # type: ignore[union-attr]
261+
regex_predictions, regex_predictions_metadata = None, None
262+
if NodeType.regex in self.nodes:
263+
regex_predictions, regex_predictions_metadata = self.nodes[NodeType.regex].module.predict_with_metadata( # type: ignore[union-attr]
264264
utterances,
265265
)
266266

@@ -269,9 +269,9 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
269269
output = InferencePipelineUtteranceOutput(
270270
utterance=utterance,
271271
prediction=predictions[idx],
272-
regexp_prediction=regexp_predictions[idx] if regexp_predictions is not None else None,
273-
regexp_prediction_metadata=regexp_predictions_metadata[idx]
274-
if regexp_predictions_metadata is not None
272+
regex_prediction=regex_predictions[idx] if regex_predictions is not None else None,
273+
regex_prediction_metadata=regex_predictions_metadata[idx]
274+
if regex_predictions_metadata is not None
275275
else None,
276276
score=scores[idx],
277277
score_metadata=scores_metadata[idx] if scores_metadata is not None else None,
@@ -280,7 +280,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
280280

281281
return InferencePipelineOutput(
282282
predictions=predictions,
283-
regexp_predictions=regexp_predictions,
283+
regex_predictions=regex_predictions,
284284
utterances=outputs,
285285
)
286286

autointent/_pipeline/_schemas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ class InferencePipelineUtteranceOutput(BaseModel):
1010

1111
utterance: str
1212
prediction: LabelWithOOS
13-
regexp_prediction: LabelWithOOS
14-
regexp_prediction_metadata: Any
13+
regex_prediction: LabelWithOOS
14+
regex_prediction_metadata: Any
1515
score: list[float]
1616
score_metadata: Any
1717

@@ -20,5 +20,5 @@ class InferencePipelineOutput(BaseModel):
2020
"""Output of the inference pipeline."""
2121

2222
predictions: ListOfLabelsWithOOS
23-
regexp_predictions: ListOfLabels | None = None
23+
regex_predictions: ListOfLabels | None = None
2424
utterances: list[InferencePipelineUtteranceOutput] | None = None

autointent/context/_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,5 @@ def has_saved_modules(self) -> bool:
172172
173173
:return: True if there are saved modules, False otherwise.
174174
"""
175-
node_types = ["regexp", "embedding", "scoring", "decision"]
175+
node_types = ["regex", "embedding", "scoring", "decision"]
176176
return any(len(self.optimization_info.modules.get(nt)) > 0 for nt in node_types)

autointent/context/data_handler/_data_handler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ class RegexPatterns(TypedDict):
2020

2121
id: int
2222
"""Intent class id."""
23-
regexp_full_match: list[str]
23+
regex_full_match: list[str]
2424
"""Full match regex patterns."""
25-
regexp_partial_match: list[str]
25+
regex_partial_match: list[str]
2626
"""Partial match regex patterns."""
2727

2828

@@ -59,11 +59,11 @@ def __init__(
5959
elif scheme == "cv":
6060
self._split_cv()
6161

62-
self.regexp_patterns = [
62+
self.regex_patterns = [
6363
RegexPatterns(
6464
id=intent.id,
65-
regexp_full_match=intent.regexp_full_match,
66-
regexp_partial_match=intent.regexp_partial_match,
65+
regex_full_match=intent.regex_full_match,
66+
regex_partial_match=intent.regex_partial_match,
6767
)
6868
for intent in self.dataset.intents
6969
]
autointent/context/optimization_info/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._data_models import Artifact, DecisionArtifact, RetrieverArtifact, ScorerArtifact
1+
from ._data_models import Artifact, DecisionArtifact, EmbeddingArtifact, ScorerArtifact
22
from ._optimization_info import OptimizationInfo
33

4-
__all__ = ["Artifact", "DecisionArtifact", "OptimizationInfo", "RetrieverArtifact", "ScorerArtifact"]
4+
__all__ = ["Artifact", "DecisionArtifact", "EmbeddingArtifact", "OptimizationInfo", "ScorerArtifact"]

autointent/context/optimization_info/_data_models.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ class Artifact(BaseModel):
1818
"""Base class for artifacts generated by pipeline nodes."""
1919

2020

21-
class RegexpArtifact(Artifact):
22-
"""Artifact containing results from the regexp node."""
21+
class RegexArtifact(Artifact):
22+
"""Artifact containing results from the regex node."""
2323

2424

25-
class RetrieverArtifact(Artifact):
25+
class EmbeddingArtifact(Artifact):
2626
"""
2727
Artifact containing details from the embedding node.
2828
@@ -68,9 +68,9 @@ def validate_node_name(value: str) -> str:
6868
:return: Validated node type string.
6969
:raises ValueError: If the node type is invalid.
7070
"""
71-
if value in [NodeType.embedding, NodeType.scoring, NodeType.decision, NodeType.regexp]:
71+
if value in [NodeType.embedding, NodeType.scoring, NodeType.decision, NodeType.regex]:
7272
return value
73-
msg = f"Unknown node_type: {value}. Expected one of ['regexp', 'embedding', 'scoring', 'decision']"
73+
msg = f"Unknown node_type: {value}. Expected one of ['regex', 'embedding', 'scoring', 'decision']"
7474
raise ValueError(msg)
7575

7676

@@ -83,8 +83,8 @@ class Artifacts(BaseModel):
8383

8484
model_config = ConfigDict(arbitrary_types_allowed=True)
8585

86-
regexp: list[RegexpArtifact] = []
87-
embedding: list[RetrieverArtifact] = []
86+
regex: list[RegexArtifact] = []
87+
embedding: list[EmbeddingArtifact] = []
8888
scoring: list[ScorerArtifact] = []
8989
decision: list[DecisionArtifact] = []
9090

@@ -135,7 +135,7 @@ class Trial(BaseModel):
135135
class Trials(BaseModel):
136136
"""Container for managing optimization trials for pipeline nodes."""
137137

138-
regexp: list[Trial] = []
138+
regex: list[Trial] = []
139139
embedding: list[Trial] = []
140140
scoring: list[Trial] = []
141141
decision: list[Trial] = []
@@ -172,8 +172,8 @@ def add_trial(self, node_type: str, trial: Trial) -> None:
172172
class TrialsIds(BaseModel):
173173
"""Representation of the best trial IDs for each pipeline node."""
174174

175-
regexp: int | None = None
176-
"""Best trial index for the regexp node."""
175+
regex: int | None = None
176+
"""Best trial index for the regex node."""
177177
embedding: int | None = None
178178
"""Best trial index for the embedding node."""
179179
scoring: int | None = None

autointent/context/optimization_info/_optimization_info.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@
1515
from autointent.custom_types import NodeType
1616
from autointent.schemas import EmbedderConfig
1717

18-
from ._data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
18+
from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials, TrialsIds
1919

2020
if TYPE_CHECKING:
21-
from autointent.modules.abc import Module
21+
from autointent.modules.abc import BaseModule
2222

2323

2424
@dataclass
2525
class ModulesList:
2626
"""Container for managing lists of modules for each node type."""
2727

28-
regexp: list["Module"] = field(default_factory=list)
29-
embedding: list["Module"] = field(default_factory=list)
30-
scoring: list["Module"] = field(default_factory=list)
31-
decision: list["Module"] = field(default_factory=list)
28+
regex: list["BaseModule"] = field(default_factory=list)
29+
embedding: list["BaseModule"] = field(default_factory=list)
30+
scoring: list["BaseModule"] = field(default_factory=list)
31+
decision: list["BaseModule"] = field(default_factory=list)
3232

33-
def get(self, node_type: str) -> list["Module"]:
33+
def get(self, node_type: str) -> list["BaseModule"]:
3434
"""
3535
Retrieve the list of modules for a specific node type.
3636
37-
:param node_type: The type of node (e.g., "regexp", "embedding").
37+
:param node_type: The type of node (e.g., "regex", "embedding").
3838
:return: List of modules for the specified node type.
3939
"""
4040
return getattr(self, node_type) # type: ignore[no-any-return]
4141

42-
def add_module(self, node_type: str, module: "Module") -> None:
42+
def add_module(self, node_type: str, module: "BaseModule") -> None:
4343
"""
4444
Add a module to the list for a specific node type.
4545
@@ -77,7 +77,7 @@ def log_module_optimization(
7777
metric_name: str,
7878
artifact: Artifact,
7979
module_dump_dir: str | None,
80-
module: "Module | None" = None,
80+
module: "BaseModule | None" = None,
8181
) -> None:
8282
"""
8383
Log optimization results for a module.
@@ -126,7 +126,7 @@ def _get_best_trial_idx(self, node_type: str) -> int | None:
126126
self._trials_best_ids.set_best_trial_idx(node_type, best_idx)
127127
return best_idx
128128

129-
def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifact | Artifact:
129+
def _get_best_artifact(self, node_type: str) -> EmbeddingArtifact | ScorerArtifact | Artifact:
130130
"""
131131
Retrieve the best artifact for a specific node type.
132132
@@ -146,7 +146,7 @@ def get_best_embedder(self) -> EmbedderConfig:
146146
147147
:return: Name of the best embedder.
148148
"""
149-
best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type=NodeType.embedding) # type: ignore[assignment]
149+
best_retriever_artifact: EmbeddingArtifact = self._get_best_artifact(node_type=NodeType.embedding) # type: ignore[assignment]
150150
return best_retriever_artifact.config
151151

152152
def get_best_train_scores(self) -> NDArray[np.float64] | None:
@@ -219,7 +219,7 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode
219219
res.append(item if asdict else InferenceNodeConfig(**item)) # type: ignore[arg-type]
220220
return res # type: ignore[return-value]
221221

222-
def _get_best_module(self, node_type: str) -> "Module | None":
222+
def _get_best_module(self, node_type: str) -> "BaseModule | None":
223223
"""
224224
Retrieve the best module for a specific node type.
225225
@@ -231,7 +231,7 @@ def _get_best_module(self, node_type: str) -> "Module | None":
231231
return self.modules.get(node_type)[idx]
232232
return None
233233

234-
def get_best_modules(self) -> dict[NodeType, "Module"]:
234+
def get_best_modules(self) -> dict[NodeType, "BaseModule"]:
235235
"""
236236
Retrieve the best modules for all node types.
237237

autointent/custom_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class LogLevel(Enum):
5151
class NodeType(str, Enum):
5252
"""Enumeration of node types in the AutoIntent pipeline."""
5353

54-
regexp = "regexp"
54+
regex = "regex"
5555
embedding = "embedding"
5656
scoring = "scoring"
5757
decision = "decision"

autointent/generation/intents/description_generation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def create_intent_description(
3838
client: AsyncOpenAI,
3939
intent_name: str | None,
4040
utterances: list[str],
41-
regexp_patterns: list[str],
41+
regex_patterns: list[str],
4242
prompt: PromptDescription,
4343
model_name: str,
4444
) -> str:
@@ -48,7 +48,7 @@ async def create_intent_description(
4848
:param client: The OpenAI client instance used to communicate with the model.
4949
:param intent_name: The name of the intent to describe. If None, an empty string will be used.
5050
:param utterances: A list of example utterances related to the intent.
51-
:param regexp_patterns: A list of regular expression patterns associated with the intent.
51+
:param regex_patterns: A list of regular expression patterns associated with the intent.
5252
5353
:param prompt: A string template for the prompt, which must include placeholders for {intent_name}
5454
and {user_utterances} to format the content sent to the model.
@@ -58,12 +58,12 @@ async def create_intent_description(
5858
"""
5959
intent_name = intent_name if intent_name is not None else ""
6060
utterances = random.sample(utterances, min(5, len(utterances)))
61-
regexp_patterns = random.sample(regexp_patterns, min(3, len(regexp_patterns)))
61+
regex_patterns = random.sample(regex_patterns, min(3, len(regex_patterns)))
6262

6363
content = prompt.text.format(
6464
intent_name=intent_name,
6565
user_utterances="\n".join(utterances),
66-
regexp_patterns="\n".join(regexp_patterns),
66+
regex_patterns="\n".join(regex_patterns),
6767
)
6868
chat_completion = await client.chat.completions.create(
6969
messages=[{"role": "user", "content": content}],
@@ -102,13 +102,13 @@ async def generate_intent_descriptions(
102102
if intent.description is not None:
103103
continue
104104
utterances = intent_utterances.get(intent.id, [])
105-
regexp_patterns = intent.regexp_full_match + intent.regexp_partial_match
105+
regex_patterns = intent.regex_full_match + intent.regex_partial_match
106106
task = asyncio.create_task(
107107
create_intent_description(
108108
client=client,
109109
intent_name=intent.name,
110110
utterances=utterances,
111-
regexp_patterns=regexp_patterns,
111+
regex_patterns=regex_patterns,
112112
prompt=prompt,
113113
model_name=model_name,
114114
),

autointent/generation/intents/prompt_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class PromptDescription(BaseModel):
1414
Should include placeholders for {intent_name} and {user_utterances}.
1515
- `{intent_name}` will be replaced with the name of the intent.
1616
- `{user_utterances}` will be replaced with the user utterances related to the intent.
17-
- (optionally) `{regexp_patterns}` will be replaced with the regular expressions that match user utterances.
17+
- (optionally) `{regex_patterns}` will be replaced with the regular expressions that match user utterances.
1818
"""
1919

2020
@classmethod

0 commit comments

Comments
 (0)