Skip to content

Commit 0cf1e74

Browse files
committed
upd assertions about catboost predictions
1 parent 31fa6e8 commit 0cf1e74

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/modules/scoring/test_catboost.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_catboost_scorer_dump_load(dataset):
1818
data_handler = DataHandler(dataset)
1919

2020
scorer_original = CatBoostScorer(
21+
embedder_config=get_test_embedder_config(),
2122
iterations=50,
2223
learning_rate=0.05,
2324
depth=6,
@@ -83,11 +84,11 @@ def test_catboost_prediction_multilabel(dataset):
8384
predictions,
8485
np.array(
8586
[
86-
[0.41777172, 0.5278134, 0.41807876, 0.4174544],
87-
[0.40775846, 0.46434019, 0.42728555, 0.43836945],
88-
[0.4207232, 0.49201536, 0.42798494, 0.41541217],
89-
[0.46765036, 0.45065999, 0.49705517, 0.45052473],
90-
[0.41694272, 0.54160408, 0.40944069, 0.41674984],
87+
[0.37150982, 0.5935175, 0.36279131, 0.37357718],
88+
[0.37309364, 0.53746911, 0.38326219, 0.39884488],
89+
[0.37744044, 0.56529594, 0.37456834, 0.38646843],
90+
[0.41484185, 0.48539558, 0.41669755, 0.42929345],
91+
[0.38344306, 0.58516115, 0.37940454, 0.39640789],
9192
]
9293
),
9394
rtol=0.01,
@@ -133,6 +134,7 @@ def test_catboost_cache_clearing(dataset):
133134
"""Test that the transformer model properly handles cache clearing."""
134135
data_handler = DataHandler(dataset)
135136
scorer = CatBoostScorer(
137+
embedder_config=get_test_embedder_config(),
136138
iterations=50,
137139
learning_rate=0.05,
138140
depth=6,

0 commit comments

Comments
 (0)