@@ -18,6 +18,7 @@ def test_catboost_scorer_dump_load(dataset):
1818 data_handler = DataHandler (dataset )
1919
2020 scorer_original = CatBoostScorer (
21+ embedder_config = get_test_embedder_config (),
2122 iterations = 50 ,
2223 learning_rate = 0.05 ,
2324 depth = 6 ,
@@ -83,11 +84,11 @@ def test_catboost_prediction_multilabel(dataset):
8384 predictions ,
8485 np .array (
8586 [
86- [0.41777172 , 0.5278134 , 0.41807876 , 0.4174544 ],
87- [0.40775846 , 0.46434019 , 0.42728555 , 0.43836945 ],
88- [0.4207232 , 0.49201536 , 0.42798494 , 0.41541217 ],
89- [0.46765036 , 0.45065999 , 0.49705517 , 0.45052473 ],
90- [0.41694272 , 0.54160408 , 0.40944069 , 0.41674984 ],
87+ [0.37150982 , 0.5935175 , 0.36279131 , 0.37357718 ],
88+ [0.37309364 , 0.53746911 , 0.38326219 , 0.39884488 ],
89+ [0.37744044 , 0.56529594 , 0.37456834 , 0.38646843 ],
90+ [0.41484185 , 0.48539558 , 0.41669755 , 0.42929345 ],
91+ [0.38344306 , 0.58516115 , 0.37940454 , 0.39640789 ],
9192 ]
9293 ),
9394 rtol = 0.01 ,
@@ -133,6 +134,7 @@ def test_catboost_cache_clearing(dataset):
133134 """Test that the transformer model properly handles cache clearing."""
134135 data_handler = DataHandler (dataset )
135136 scorer = CatBoostScorer (
137+ embedder_config = get_test_embedder_config (),
136138 iterations = 50 ,
137139 learning_rate = 0.05 ,
138140 depth = 6 ,
0 commit comments