Skip to content

Commit 1ccb642

Browse files
committed
fix multilabel prediction
1 parent aa45b90 commit 1ccb642

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def fit(
209209
if self._multilabel:
210210
y_mat = np.zeros((len(labels), self._n_classes), dtype=np.float32)
211211
for i, lbls in enumerate(cast("Sequence[Sequence[int]]", labels)):
212-
for lbl in lbls:
213-
y_mat[i, lbl] = 1.0
212+
for class_i, lbl in enumerate(lbls):
213+
y_mat[i, class_i] = lbl
214214
y = y_mat
215215
else:
216216
y = np.asarray(cast("Sequence[int]", labels), dtype=np.int64)

tests/modules/scoring/test_catboost.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,72 @@ def test_catboost_prediction(dataset):
9090
assert metadata is None
9191

9292

93+
def test_catboost_prediction_multilabel(dataset):
94+
"""Test that the transformer model can fit and make predictions."""
95+
data_handler = DataHandler(dataset.to_multilabel())
96+
97+
scorer = CatBoostScorer(
98+
classification_model_config="prajjwal1/bert-tiny",
99+
iterations=50,
100+
learning_rate=0.05,
101+
depth=6,
102+
l2_leaf_reg=3,
103+
eval_metric="Accuracy",
104+
random_seed=42,
105+
verbose=False,
106+
)
107+
108+
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
109+
110+
test_data = [
111+
"why is there a hold on my american saving bank account",
112+
"i am nost sure why my account is blocked",
113+
"why is there a hold on my capital one checking account",
114+
"i think my account is blocked but i do not know the reason",
115+
"can you tell me why is my bank account frozen",
116+
]
117+
118+
predictions = scorer.predict(test_data)
119+
assert np.allclose(
120+
predictions,
121+
np.array(
122+
[
123+
[
124+
0.22828311,
125+
0.70298906,
126+
0.24396814,
127+
0.2318292,
128+
],
129+
[
130+
0.21511787,
131+
0.43272557,
132+
0.28723239,
133+
0.40194354,
134+
],
135+
[
136+
0.24727756,
137+
0.65392399,
138+
0.22263033,
139+
0.27726414,
140+
],
141+
[
142+
0.26847769,
143+
0.39022974,
144+
0.28379654,
145+
0.4868582,
146+
],
147+
[
148+
0.11476477,
149+
0.86928679,
150+
0.11779149,
151+
0.12179479,
152+
],
153+
]
154+
),
155+
1e-2,
156+
)
157+
158+
93159
def test_catboost_without_embedder(dataset):
94160
"""Test that CatBoostScorer works properly without an embedder (using BoW encoding)."""
95161
data_handler = DataHandler(dataset)

0 commit comments

Comments
 (0)