Skip to content

Commit b011d98

Browse files
committed
fix tests
1 parent 0e19470 commit b011d98

File tree

3 files changed

+12
-36
lines changed

3 files changed

+12
-36
lines changed

autointent/_dump_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ def load( # noqa: C901, PLR0912, PLR0915
254254
msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
255255
logger.exception(msg)
256256
elif child.name == Dumper.catboost_models:
257-
cat_model = CatBoostClassifier(str(path / Dumper.catboost_models / "model.cbm"))
258-
cat_model.load_model()
257+
cat_model = CatBoostClassifier()
258+
cat_model.load_model(str(path / Dumper.catboost_models / "model.cbm"))
259259
else:
260260
msg = f"Found unexpected child {child}"
261261
logger.error(msg)

autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""CatBoostScorer class for CatBoost-based classification with switchable encoding."""
2+
23
import logging
34
from enum import StrEnum
45
from typing import Any, cast
@@ -60,7 +61,7 @@ class CatBoostScorer(BaseScorer):
6061
eval_metric="Accuracy",
6162
random_seed=42,
6263
verbose=False,
63-
features_type="text", # or "embedding" or "both"
64+
features_type="embedding", # or "text" or "both"
6465
)
6566
utterances = ["hello", "goodbye", "allo", "sayonara"]
6667
labels = [0, 1, 0, 1]
@@ -71,8 +72,8 @@ class CatBoostScorer(BaseScorer):
7172
7273
.. testoutput::
7374
74-
[[0.50525691 0.49474309]
75-
[0.50525691 0.49474309]]
75+
[[0.41493207 0.58506793]
76+
[0.55036046 0.44963954]]
7677
7778
"""
7879

tests/modules/scoring/test_catboost.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -78,39 +78,14 @@ def test_catboost_prediction_multilabel(dataset):
7878
predictions,
7979
np.array(
8080
[
81-
[
82-
0.2,
83-
0.7,
84-
0.2,
85-
0.2,
86-
],
87-
[
88-
0.2,
89-
0.4,
90-
0.2,
91-
0.4,
92-
],
93-
[
94-
0.2,
95-
0.6,
96-
0.2,
97-
0.2,
98-
],
99-
[
100-
0.2,
101-
0.3,
102-
0.2,
103-
0.4,
104-
],
105-
[
106-
0.1,
107-
0.8,
108-
0.1,
109-
0.1,
110-
],
81+
[0.41777172, 0.5278134, 0.41807876, 0.4174544],
82+
[0.40775846, 0.46434019, 0.42728555, 0.43836945],
83+
[0.4207232, 0.49201536, 0.42798494, 0.41541217],
84+
[0.46765036, 0.45065999, 0.49705517, 0.45052473],
85+
[0.41694272, 0.54160408, 0.40944069, 0.41674984],
11186
]
11287
),
113-
0.1,
88+
rtol=0.01,
11489
)
11590

11691

0 commit comments

Comments
 (0)