Skip to content

Commit 8da2275

Browse files
committed
fix typing
1 parent f2a916f commit 8da2275

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __init__(
106106
self.loss_function = loss_function
107107
self.verbose = verbose
108108
self.catboost_kwargs = catboost_kwargs or {}
109+
self.catboost_kwargs |= self.get_extra_params()
109110

110111
@classmethod
111112
def from_context(
@@ -125,6 +126,7 @@ def from_context(
125126
loss_function=loss_function,
126127
verbose=verbose,
127128
features_type=features_type,
129+
use_embedding_features=use_embedding_features,
128130
**catboost_kwargs,
129131
)
130132

@@ -144,12 +146,26 @@ def _prepare_data_for_fit(
144146
if self.use_embedding_features:
145147
data = pd.DataFrame({"embedding": encoded_utterances})
146148
else:
147-
data = pd.DataFrame(encoded_utterances)
149+
data = pd.DataFrame(np.array(encoded_utterances))
148150
if self.features_type == FeaturesType.BOTH:
149151
data["text"] = utterances
150152
return data
151153
return pd.DataFrame({"text": utterances})
152154

155+
def get_extra_params(self) -> dict[str, Any]:
156+
extra_params = {}
157+
if self.features_type == FeaturesType.EMBEDDING:
158+
if self.use_embedding_features: # to not raise error if embedding witout embedding_features
159+
extra_params["embedding_features"] = ["embedding"]
160+
elif self.features_type in {FeaturesType.TEXT, FeaturesType.BOTH}:
161+
extra_params["text_features"] = ["text"]
162+
if self.features_type == FeaturesType.BOTH and self.use_embedding_features:
163+
extra_params["embedding_features"] = ["embedding"]
164+
else:
165+
msg = f"Unsupported features type: {self.features_type}"
166+
raise ValueError(msg)
167+
return extra_params
168+
153169
def fit(
154170
self,
155171
utterances: list[str],
@@ -167,19 +183,6 @@ def fit(
167183
else ("MultiClass" if self._n_classes > BINARY_CLASS_THRESHOLD else "Logloss")
168184
)
169185

170-
extra_params = {}
171-
if self.features_type == FeaturesType.EMBEDDING:
172-
if self.use_embedding_features: # to not raise error if embedding witout embedding_features
173-
extra_params["embedding_features"] = ["embedding"]
174-
elif self.features_type in {FeaturesType.TEXT, FeaturesType.BOTH}:
175-
extra_params["text_features"] = ["text"]
176-
if self.features_type == FeaturesType.BOTH and self.use_embedding_features:
177-
extra_params["embedding_features"] = ["embedding"]
178-
else:
179-
msg = f"Unsupported features type: {self.features_type}"
180-
raise ValueError(msg)
181-
self.catboost_kwargs.update(extra_params)
182-
183186
self._model = CatBoostClassifier(
184187
loss_function=self.loss_function or default_loss,
185188
verbose=self.verbose,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ typing = [
7070
"types-pygments (>=2.18.0.20240506,<3.0.0)",
7171
"types-setuptools (>=75.2.0.20241019,<76.0.0)",
7272
"joblib-stubs (>=1.4.2.5.20240918,<2.0.0)",
73+
"pandas-stubs (>= 2.2.3.250527, <3.0.0)",
7374
]
7475
docs = [
7576
"sphinx (>=8.1.3,<9.0.0)",

0 commit comments

Comments
 (0)