File tree Expand file tree Collapse file tree 2 files changed +68
-2
lines changed
autointent/modules/scoring/_catboost Expand file tree Collapse file tree 2 files changed +68
-2
lines changed Original file line number Diff line number Diff line change @@ -209,8 +209,8 @@ def fit(
209209 if self ._multilabel :
210210 y_mat = np .zeros ((len (labels ), self ._n_classes ), dtype = np .float32 )
211211 for i , lbls in enumerate (cast ("Sequence[Sequence[int]]" , labels )):
212- for lbl in lbls :
213- y_mat [i , lbl ] = 1.0
212+ for class_i , lbl in enumerate ( lbls ) :
213+ y_mat [i , class_i ] = lbl
214214 y = y_mat
215215 else :
216216 y = np .asarray (cast ("Sequence[int]" , labels ), dtype = np .int64 )
Original file line number Diff line number Diff line change @@ -90,6 +90,72 @@ def test_catboost_prediction(dataset):
9090 assert metadata is None
9191
9292
def test_catboost_prediction_multilabel(dataset):
    """Check that CatBoostScorer fits on multilabel data and reproduces reference scores.

    The dataset fixture is converted with ``to_multilabel()`` before training.
    The scorer output for five probe utterances is then compared against a
    hard-coded 5x4 matrix of expected scores.
    """
    multilabel_handler = DataHandler(dataset.to_multilabel())

    scorer = CatBoostScorer(
        classification_model_config="prajjwal1/bert-tiny",
        iterations=50,
        learning_rate=0.05,
        depth=6,
        l2_leaf_reg=3,
        eval_metric="Accuracy",
        random_seed=42,
        verbose=False,
    )
    train_utterances = multilabel_handler.train_utterances(0)
    train_labels = multilabel_handler.train_labels(0)
    scorer.fit(train_utterances, train_labels)

    # Probe utterances are kept verbatim (typos included): the reference
    # scores below were recorded for exactly these strings.
    probe_utterances = [
        "why is there a hold on my american saving bank account",
        "i am nost sure why my account is blocked",
        "why is there a hold on my capital one checking account",
        "i think my account is blocked but i do not know the reason",
        "can you tell me why is my bank account frozen",
    ]

    # One row per probe utterance, one column per class.
    expected_scores = np.array(
        [
            [0.22828311, 0.70298906, 0.24396814, 0.2318292],
            [0.21511787, 0.43272557, 0.28723239, 0.40194354],
            [0.24727756, 0.65392399, 0.22263033, 0.27726414],
            [0.26847769, 0.39022974, 0.28379654, 0.4868582],
            [0.11476477, 0.86928679, 0.11779149, 0.12179479],
        ]
    )

    actual_scores = scorer.predict(probe_utterances)

    # The third positional argument of np.allclose is the relative tolerance;
    # it is spelled out here as a keyword for readability.
    assert np.allclose(actual_scores, expected_scores, rtol=1e-2)
157+
158+
93159def test_catboost_without_embedder (dataset ):
94160 """Test that CatBoostScorer works properly without an embedder (using BoW encoding)."""
95161 data_handler = DataHandler (dataset )
You can’t perform that action at this time.
0 commit comments