Skip to content

Commit cf325f1

Browse files
committed
update after merge
1 parent 4be65a2 commit cf325f1

File tree

3 files changed

+29
-57
lines changed

3 files changed

+29
-57
lines changed

autointent/_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
"""
7171
Initialize the Embedder.
7272
73-
:param model_name: Path to a local model directory or a Hugging Face model name.
73+
:param model_name_or_path: Path to a local model directory or a Hugging Face model name.
7474
:param device: Device to run the model on (e.g., "cpu", "cuda").
7575
:param batch_size: Batch size for embedding calculations.
7676
:param max_length: Maximum sequence length for the embedding model.

autointent/modules/__init__.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,23 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2525
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
2626

2727
SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = _create_modules_dict(
28-
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer, RerankScorer, SklearnScorer,]
28+
[
29+
DNNCScorer,
30+
KNNScorer,
31+
LinearScorer,
32+
DescriptionScorer,
33+
RerankScorer,
34+
SklearnScorer,
35+
]
2936
)
3037

3138
SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = _create_modules_dict(
32-
[MLKnnScorer, LinearScorer, DescriptionScorer, SklearnScorer, ],
39+
[
40+
MLKnnScorer,
41+
LinearScorer,
42+
DescriptionScorer,
43+
SklearnScorer,
44+
],
3345
)
3446

3547
PREDICTION_MODULES_MULTICLASS: dict[str, type[DecisionModule]] = _create_modules_dict(
@@ -40,29 +52,4 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4052
[AdaptiveDecision, ThresholdDecision, TunableDecision],
4153
)
4254

43-
__all__ = [
44-
"PREDICTION_MODULES_MULTICLASS",
45-
"PREDICTION_MODULES_MULTILABEL",
46-
"RETRIEVAL_MODULES_MULTICLASS",
47-
"RETRIEVAL_MODULES_MULTILABEL",
48-
"SCORING_MODULES_MULTICLASS",
49-
"SCORING_MODULES_MULTILABEL",
50-
"AdaptivePredictor",
51-
"ArgmaxPredictor",
52-
"DNNCScorer",
53-
"DescriptionScorer",
54-
"JinoosPredictor",
55-
"KNNScorer",
56-
"LinearScorer",
57-
"MLKnnScorer",
58-
"Module",
59-
"PredictionModule",
60-
"RegExp",
61-
"RerankScorer",
62-
"RetrievalModule",
63-
"ScoringModule",
64-
"ThresholdPredictor",
65-
"TunablePredictor",
66-
"VectorDBModule",
67-
"SklearnScorer",
68-
]
55+
__all__ = [] # type: ignore[var-annotated]

autointent/modules/scoring/_sklearn/scorer.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import joblib
66
import numpy as np
77
import numpy.typing as npt
8+
from sklearn.linear_model import LogisticRegression
89
from sklearn.multioutput import MultiOutputClassifier
910
from sklearn.utils import all_estimators
1011
from typing_extensions import Self
1112

1213
from autointent import Context, Embedder
13-
from autointent.context.vector_index_client import VectorIndexClient
1414
from autointent.custom_types import BaseMetadataDict, LabelType
15-
from autointent.modules.scoring._base import ScoringModule
15+
from autointent.modules.abc import ScoringModule
1616

1717
AVAILIABLE_CLASSIFIERS = {name: class_ for name, class_ in all_estimators() if hasattr(class_, "predict_proba")}
1818

@@ -90,7 +90,7 @@ def __init__(
9090
def from_context(
9191
cls,
9292
context: Context,
93-
clf_name: str,
93+
clf_name: str = LogisticRegression.__name__,
9494
clf_args: dict[str, Any] | None = None,
9595
embedder_name: str | None = None,
9696
) -> Self:
@@ -105,10 +105,8 @@ def from_context(
105105
"""
106106
if embedder_name is None:
107107
embedder_name = context.optimization_info.get_best_embedder()
108-
precomputed_embeddings = True
109-
else:
110-
precomputed_embeddings = context.vector_index_client.exists(embedder_name)
111-
instance = cls(
108+
109+
return cls(
112110
embedder_name=embedder_name,
113111
device=context.get_device(),
114112
seed=context.seed,
@@ -117,9 +115,6 @@ def from_context(
117115
clf_name=clf_name,
118116
clf_args=clf_args,
119117
)
120-
instance.precomputed_embeddings = precomputed_embeddings
121-
instance.db_dir = str(context.get_db_dir())
122-
return instance
123118

124119
def fit(
125120
self,
@@ -135,23 +130,13 @@ def fit(
135130
"""
136131
self._multilabel = isinstance(labels[0], list)
137132

138-
if self.precomputed_embeddings:
139-
# this happens only when SklearnScorer is within Pipeline optimization after RetrievalNode optimization
140-
vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
141-
vector_index = vector_index_client.get_index(self.embedder_name)
142-
features = vector_index.get_all_embeddings()
143-
if len(features) != len(utterances):
144-
msg = "Vector index mismatches provided utterances"
145-
raise ValueError(msg)
146-
embedder = vector_index.embedder
147-
else:
148-
embedder = Embedder(
149-
device=self.device,
150-
model_name=self.embedder_name,
151-
batch_size=self.batch_size,
152-
max_length=self.max_length,
153-
)
154-
features = embedder.embed(utterances)
133+
embedder = Embedder(
134+
device=self.device,
135+
model_name_or_path=self.embedder_name,
136+
batch_size=self.batch_size,
137+
max_length=self.max_length,
138+
)
139+
features = embedder.embed(utterances)
155140
self.clf_args = {} if self.clf_args is None else self.clf_args
156141
if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
157142
base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
@@ -229,7 +214,7 @@ def load(self, path: str) -> None:
229214
embedder_dir = dump_dir / self.embedding_model_subdir
230215
self._embedder = Embedder(
231216
device=self.device,
232-
model_name=embedder_dir,
217+
model_name_or_path=embedder_dir,
233218
batch_size=metadata["batch_size"],
234219
max_length=metadata["max_length"],
235220
)

0 commit comments

Comments
 (0)