Skip to content

Commit de779c0

Browse files
truff4utvoorhsSamoedDarinochka
authored
Add enhanced output (#38)
Co-authored-by: voorhs <[email protected]> Co-authored-by: Roman Solomatin <[email protected]> Co-authored-by: Darinka <[email protected]>
1 parent ad097e8 commit de779c0

File tree

29 files changed

+694
-255
lines changed

29 files changed

+694
-255
lines changed

autointent/configs/inference_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class InferenceConfig:
1111
source_dir: str
1212
output_path: str
1313
log_level: LogLevel = LogLevel.ERROR
14+
with_metadata: bool = False
1415

1516

1617
cs = ConfigStore.instance()

autointent/context/embedder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,8 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
8181
if self.max_length is not None:
8282
self.embedding_model.max_seq_length = self.max_length
8383
return self.embedding_model.encode(
84-
utterances, convert_to_numpy=True, batch_size=self.batch_size, normalize_embeddings=True,
84+
utterances,
85+
convert_to_numpy=True,
86+
batch_size=self.batch_size,
87+
normalize_embeddings=True,
8588
) # type: ignore[return-value]

autointent/modules/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def load(self, path: str) -> None:
4848
def predict(self, *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any]) -> npt.NDArray[Any]:
4949
"""inference"""
5050

51+
def predict_with_metadata(
52+
self,
53+
*args: list[str] | npt.NDArray[Any],
54+
**kwargs: dict[str, Any],
55+
) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
56+
return self.predict(*args, **kwargs), None
57+
5158
@classmethod
5259
@abstractmethod
5360
def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> Self:

autointent/modules/regexp.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import json
12
import re
2-
from typing import TypedDict
3+
from pathlib import Path
4+
from typing import Any, TypedDict
35

46
from typing_extensions import Self
57

68
from autointent import Context
79
from autointent.context.data_handler.data_handler import RegexPatterns
10+
from autointent.context.data_handler.schemas import Intent
811
from autointent.context.optimization_info.data_models import Artifact
912
from autointent.custom_types import LabelType
1013
from autointent.metrics.regexp import RegexpMetricFn
@@ -19,43 +22,60 @@ class RegexPatternsCompiled(TypedDict):
1922

2023

2124
class RegExp(Module):
22-
name = "regexp"
23-
24-
def __init__(self, regexp_patterns: list[RegexPatterns]) -> None:
25-
self.regexp_patterns = regexp_patterns
26-
2725
@classmethod
2826
def from_context(cls, context: Context) -> Self:
29-
return cls(
30-
regexp_patterns=context.data_handler.regexp_patterns,
31-
)
32-
33-
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
34-
self.regexp_patterns_compiled: list[RegexPatternsCompiled] = [
35-
{
36-
"id": dct["id"],
37-
"regexp_full_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_full_match"]],
38-
"regexp_partial_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_partial_match"]],
39-
}
40-
for dct in self.regexp_patterns
27+
return cls()
28+
29+
def fit(self, intents: list[dict[str, Any]]) -> None:
30+
intents_parsed = [Intent(**dct) for dct in intents]
31+
self.regexp_patterns = [
32+
RegexPatterns(
33+
id=intent.id,
34+
regexp_full_match=intent.regexp_full_match,
35+
regexp_partial_match=intent.regexp_partial_match,
36+
)
37+
for intent in intents_parsed
4138
]
39+
self._compile_regex_patterns()
4240

4341
def predict(self, utterances: list[str]) -> list[LabelType]:
44-
return [list(self._predict_single(ut)) for ut in utterances]
45-
46-
def _match(self, text: str, intent_record: RegexPatternsCompiled) -> bool:
47-
full_match = any(ptn.fullmatch(text) for ptn in intent_record["regexp_full_match"])
48-
if full_match:
49-
return True
50-
return any(ptn.match(text) for ptn in intent_record["regexp_partial_match"])
42+
return [self._predict_single(utterance)[0] for utterance in utterances]
43+
44+
def predict_with_metadata(
45+
self,
46+
utterances: list[str],
47+
) -> tuple[list[LabelType], list[dict[str, Any]] | None]:
48+
predictions, metadata = [], []
49+
for utterance in utterances:
50+
prediction, matches = self._predict_single(utterance)
51+
predictions.append(prediction)
52+
metadata.append(matches)
53+
return predictions, metadata
54+
55+
def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[str, list[str]]:
56+
full_matches = [
57+
pattern.pattern
58+
for pattern in intent_record["regexp_full_match"]
59+
if pattern.fullmatch(utterance) is not None
60+
]
61+
partial_matches = [
62+
pattern.pattern
63+
for pattern in intent_record["regexp_partial_match"]
64+
if pattern.search(utterance) is not None
65+
]
66+
return {"full_matches": full_matches, "partial_matches": partial_matches}
5167

52-
def _predict_single(self, utterance: str) -> set[int]:
68+
def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str]]]:
5369
# todo test this
54-
return {
55-
intent_record["id"]
56-
for intent_record in self.regexp_patterns_compiled
57-
if self._match(utterance, intent_record)
58-
}
70+
prediction = set()
71+
matches: dict[str, list[str]] = {"full_matches": [], "partial_matches": []}
72+
for intent_record in self.regexp_patterns_compiled:
73+
intent_matches = self._match(utterance, intent_record)
74+
if intent_matches["full_matches"] or intent_matches["partial_matches"]:
75+
prediction.add(intent_record["id"])
76+
matches["full_matches"].extend(intent_matches["full_matches"])
77+
matches["partial_matches"].extend(intent_matches["partial_matches"])
78+
return list(prediction), matches
5979

6080
def score(self, context: Context, metric_fn: RegexpMetricFn) -> float:
6181
# TODO add parameter to a whole pipeline (or just to regexp module):
@@ -78,7 +98,29 @@ def get_assets(self) -> Artifact:
7898
return Artifact()
7999

80100
def load(self, path: str) -> None:
81-
pass
101+
dump_dir = Path(path)
102+
103+
with (dump_dir / self.metadata_dict_name).open() as file:
104+
self.regexp_patterns = json.load(file)
105+
106+
self._compile_regex_patterns()
82107

83108
def dump(self, path: str) -> None:
84-
pass
109+
dump_dir = Path(path)
110+
111+
with (dump_dir / self.metadata_dict_name).open("w") as file:
112+
json.dump(self.regexp_patterns, file, indent=4)
113+
114+
def _compile_regex_patterns(self) -> None:
115+
self.regexp_patterns_compiled = [
116+
RegexPatternsCompiled(
117+
id=regexp_patterns["id"],
118+
regexp_full_match=[
119+
re.compile(pattern, flags=re.IGNORECASE) for pattern in regexp_patterns["regexp_full_match"]
120+
],
121+
regexp_partial_match=[
122+
re.compile(ptn, flags=re.IGNORECASE) for ptn in regexp_patterns["regexp_partial_match"]
123+
],
124+
)
125+
for regexp_patterns in self.regexp_patterns
126+
]

autointent/modules/retrieval/vectordb.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ def __init__(
3232
batch_size: int = 32,
3333
max_length: int | None = None,
3434
) -> None:
35-
if db_dir is None:
36-
db_dir = str(get_db_dir())
3735
self.embedder_name = embedder_name
3836
self.device = device
39-
self.db_dir = db_dir
37+
self._db_dir = db_dir
4038
self.batch_size = batch_size
4139
self.max_length = max_length
4240

@@ -58,6 +56,12 @@ def from_context(
5856
max_length=context.get_max_length(),
5957
)
6058

59+
@property
60+
def db_dir(self) -> str:
61+
if self._db_dir is None:
62+
self._db_dir = str(get_db_dir())
63+
return self._db_dir
64+
6165
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
6266
vector_index_client = VectorIndexClient(
6367
self.device, self.db_dir, embedder_batch_size=self.batch_size, embedder_max_length=self.max_length

autointent/modules/scoring/description/description.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from autointent.context import Context
1212
from autointent.context.embedder import Embedder
1313
from autointent.context.vector_index_client import VectorIndex, VectorIndexClient
14-
from autointent.context.vector_index_client.cache import get_db_dir
1514
from autointent.custom_types import LabelType
1615
from autointent.modules.scoring.base import ScoringModule
1716

@@ -30,22 +29,19 @@ class DescriptionScorer(ScoringModule):
3029
precomputed_embeddings: bool = False
3130
embedding_model_subdir: str = "embedding_model"
3231
_vector_index: VectorIndex
32+
db_dir: str
3333
name = "description"
3434

3535
def __init__(
3636
self,
3737
embedder_name: str,
38-
db_dir: Path | None = None,
3938
temperature: float = 1.0,
4039
device: str = "cpu",
4140
batch_size: int = 32,
4241
max_length: int | None = None,
4342
) -> None:
44-
if db_dir is None:
45-
db_dir = get_db_dir()
4643
self.temperature = temperature
4744
self.device = device
48-
self.db_dir = db_dir
4945
self.embedder_name = embedder_name
5046
self.batch_size = batch_size
5147
self.max_length = max_length
@@ -66,10 +62,10 @@ def from_context(
6662
instance = cls(
6763
temperature=temperature,
6864
device=context.get_device(),
69-
db_dir=context.get_db_dir(),
7065
embedder_name=embedder_name,
7166
)
7267
instance.precomputed_embeddings = precomputed_embeddings
68+
instance.db_dir = str(context.get_db_dir())
7369
return instance
7470

7571
def get_embedder_name(self) -> str:

autointent/modules/scoring/dnnc/dnnc.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,21 @@ def __init__(
5252
batch_size: int = 32,
5353
max_length: int | None = None,
5454
) -> None:
55-
if db_dir is None:
56-
db_dir = str(get_db_dir())
57-
5855
self.cross_encoder_name = cross_encoder_name
5956
self.embedder_name = embedder_name
6057
self.k = k
6158
self.train_head = train_head
6259
self.device = device
63-
self.db_dir = db_dir
60+
self._db_dir = db_dir
6461
self.batch_size = batch_size
6562
self.max_length = max_length
6663

64+
@property
65+
def db_dir(self) -> str:
66+
if self._db_dir is None:
67+
self._db_dir = str(get_db_dir())
68+
return self._db_dir
69+
6770
@classmethod
6871
def from_context(
6972
cls,
@@ -114,19 +117,18 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
114117
self.model = model
115118

116119
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
117-
"""
118-
Return
119-
---
120-
`(n_queries, n_classes)` matrix with zeros everywhere except the class of the best neighbor utterance
121-
"""
122-
labels, _, texts = self.vector_index.query(
123-
utterances,
124-
self.k,
125-
)
120+
return self._predict(utterances)[0]
126121

127-
cross_encoder_scores = self._get_cross_encoder_scores(utterances, texts)
128-
129-
return self._build_result(cross_encoder_scores, labels)
122+
def predict_with_metadata(
123+
self,
124+
utterances: list[str],
125+
) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
126+
scores, neighbors, neighbors_scores = self._predict(utterances)
127+
metadata = [
128+
{"neighbors": utterance_neighbors, "scores": utterance_neighbors_scores}
129+
for utterance_neighbors, utterance_neighbors_scores in zip(neighbors, neighbors_scores, strict=True)
130+
]
131+
return scores, metadata
130132

131133
def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list[str]]) -> list[list[float]]:
132134
"""
@@ -215,6 +217,19 @@ def load(self, path: str) -> None:
215217
else:
216218
self.model = CrossEncoder(crossencoder_dir, device=self.device)
217219

220+
def _predict(
221+
self,
222+
utterances: list[str],
223+
) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]:
224+
labels, _, neigbors = self.vector_index.query(
225+
utterances,
226+
self.k,
227+
)
228+
229+
cross_encoder_scores = self._get_cross_encoder_scores(utterances, neigbors)
230+
231+
return self._build_result(cross_encoder_scores, labels), neigbors, cross_encoder_scores
232+
218233

219234
def build_result(scores: npt.NDArray[Any], labels: npt.NDArray[Any], n_classes: int) -> npt.NDArray[Any]:
220235
res = np.zeros((len(scores), n_classes))

autointent/modules/scoring/knn/knn.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,20 @@ def __init__(
4949
- closest: each sample has a non zero weight iff is the closest sample of some class
5050
- `device`: str, something like "cuda:0" or "cuda:0,1,2", a device to store embedding function
5151
"""
52-
if db_dir is None:
53-
db_dir = str(get_db_dir())
5452
self.embedder_name = embedder_name
5553
self.k = k
5654
self.weights = weights
57-
self.db_dir = db_dir
55+
self._db_dir = db_dir
5856
self.device = device
5957
self.batch_size = batch_size
6058
self.max_length = max_length
6159

60+
@property
61+
def db_dir(self) -> str:
62+
if self._db_dir is None:
63+
self._db_dir = str(get_db_dir())
64+
return self._db_dir
65+
6266
@classmethod
6367
def from_context(
6468
cls,
@@ -107,8 +111,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
107111
self._vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
108112

109113
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
110-
labels, distances, _ = self._vector_index.query(utterances, self.k)
111-
return apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
114+
return self._predict(utterances)[0]
115+
116+
def predict_with_metadata(
117+
self,
118+
utterances: list[str],
119+
) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
120+
scores, neighbors = self._predict(utterances)
121+
metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors]
122+
return scores, metadata
112123

113124
def clear_cache(self) -> None:
114125
self._vector_index.clear_ram()
@@ -145,3 +156,8 @@ def load(self, path: str) -> None:
145156
embedder_max_length=self.metadata["max_length"],
146157
)
147158
self._vector_index = vector_index_client.get_index(self.embedder_name)
159+
160+
def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
161+
labels, distances, neigbors = self._vector_index.query(utterances, self.k)
162+
scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
163+
return scores, neigbors

0 commit comments

Comments
 (0)