Skip to content

Commit cdb80dd

Browse files
authored
Update scorer.py
1 parent bbeef5d commit cdb80dd

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

autointent/modules/scoring/_sklearn/scorer.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from sklearn.utils import all_estimators
1010
from typing_extensions import Self
1111

12-
from autointent.context import Context
13-
from autointent.context.embedder import Embedder
12+
from autointent.context import Context, Embedder
1413
from autointent.context.vector_index_client import VectorIndexClient
1514
from autointent.custom_types import BaseMetadataDict, LabelType
1615
from autointent.modules.scoring._base import ScoringModule
@@ -109,21 +108,17 @@ def from_context(
109108
precomputed_embeddings = True
110109
else:
111110
precomputed_embeddings = context.vector_index_client.exists(embedder_name)
112-
context.device = context.get_device()
113-
context.embedder_batch_size = context.get_batch_size()
114-
context.embedder_max_length = context.get_max_length()
115-
context.db_dir = context.get_db_dir()
116111
instance = cls(
117-
model_name=embedder_name,
118-
device=context.device,
112+
embedder_name=embedder_name,
113+
device=context.get_device(),
119114
seed=context.seed,
120-
batch_size=context.embedder_batch_size,
121-
max_length=context.embedder_max_length,
115+
batch_size=context.get_batch_size(),
116+
max_length=context.get_max_length(),
122117
clf_name=clf_name,
123118
clf_args=clf_args,
124119
)
125120
instance.precomputed_embeddings = precomputed_embeddings
126-
instance.db_dir = str(context.db_dir)
121+
instance.db_dir = str(context.get_db_dir())
127122
return instance
128123

129124
def fit(

0 commit comments

Comments
 (0)