Skip to content

Commit f88efbb

Browse files
authored
Update scorer.py
1 parent 289daf5 commit f88efbb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

autointent/modules/scoring/_sklearn/scorer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
embedder_name: str,
5757
clf_name: str,
5858
cv: int = 3,
59-
clf_args: dict[str, Any] = {}, # noqa: B006
59+
clf_args: dict[str, Any] | None = None,
6060
n_jobs: int = -1,
6161
device: str = "cpu",
6262
seed: int = 0,
@@ -91,7 +91,7 @@ def from_context(
9191
cls,
9292
context: Context,
9393
clf_name: str,
94-
clf_args: dict[str, Any] = {}, # noqa: B006
94+
clf_args: dict[str, Any] | None = None,
9595
embedder_name: str | None = None,
9696
) -> Self:
9797
"""
@@ -136,7 +136,7 @@ def fit(
136136
self._multilabel = isinstance(labels[0], list)
137137

138138
if self.precomputed_embeddings:
139-
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
139+
# this happens only when SklearnScorer is within Pipeline opimization after RetrievalNode optimization
140140
vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
141141
vector_index = vector_index_client.get_index(self.embedder_name)
142142
features = vector_index.get_all_embeddings()
@@ -152,7 +152,7 @@ def fit(
152152
max_length=self.max_length,
153153
)
154154
features = embedder.embed(utterances)
155-
155+
self.clf_args = {} if self.clf_args is None else self.clf_args
156156
if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
157157
base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
158158
else:

0 commit comments

Comments
 (0)