@@ -5,14 +5,14 @@
 import joblib
 import numpy as np
 import numpy.typing as npt
+from sklearn.linear_model import LogisticRegression
 from sklearn.multioutput import MultiOutputClassifier
 from sklearn.utils import all_estimators
 from typing_extensions import Self

 from autointent import Context, Embedder
-from autointent.context.vector_index_client import VectorIndexClient
 from autointent.custom_types import BaseMetadataDict, LabelType
-from autointent.modules.scoring._base import ScoringModule
+from autointent.modules.abc import ScoringModule

 AVAILIABLE_CLASSIFIERS = {name: class_ for name, class_ in all_estimators() if hasattr(class_, "predict_proba")}

@@ -90,7 +90,7 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        clf_name: str,
+        clf_name: str = LogisticRegression.__name__,
         clf_args: dict[str, Any] | None = None,
         embedder_name: str | None = None,
     ) -> Self:
@@ -105,10 +105,8 @@ def from_context(
         """
         if embedder_name is None:
             embedder_name = context.optimization_info.get_best_embedder()
-            precomputed_embeddings = True
-        else:
-            precomputed_embeddings = context.vector_index_client.exists(embedder_name)
-        instance = cls(
+
+        return cls(
             embedder_name=embedder_name,
             device=context.get_device(),
             seed=context.seed,
@@ -117,9 +115,6 @@ def from_context(
             clf_name=clf_name,
             clf_args=clf_args,
         )
-        instance.precomputed_embeddings = precomputed_embeddings
-        instance.db_dir = str(context.get_db_dir())
-        return instance

     def fit(
         self,
@@ -135,23 +130,13 @@ def fit(
         """
         self._multilabel = isinstance(labels[0], list)

-        if self.precomputed_embeddings:
-            # this happens only when SklearnScorer is within Pipeline optimization after RetrievalNode optimization
-            vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
-            vector_index = vector_index_client.get_index(self.embedder_name)
-            features = vector_index.get_all_embeddings()
-            if len(features) != len(utterances):
-                msg = "Vector index mismatches provided utterances"
-                raise ValueError(msg)
-            embedder = vector_index.embedder
-        else:
-            embedder = Embedder(
-                device=self.device,
-                model_name=self.embedder_name,
-                batch_size=self.batch_size,
-                max_length=self.max_length,
-            )
-            features = embedder.embed(utterances)
+        embedder = Embedder(
+            device=self.device,
+            model_name_or_path=self.embedder_name,
+            batch_size=self.batch_size,
+            max_length=self.max_length,
+        )
+        features = embedder.embed(utterances)
         self.clf_args = {} if self.clf_args is None else self.clf_args
         if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
             base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
@@ -229,7 +214,7 @@ def load(self, path: str) -> None:
         embedder_dir = dump_dir / self.embedding_model_subdir
         self._embedder = Embedder(
             device=self.device,
-            model_name=embedder_dir,
+            model_name_or_path=embedder_dir,
             batch_size=metadata["batch_size"],
             max_length=metadata["max_length"],
         )