File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change 99
99
if 'cudf' in str (type (y_train )):
100
100
params .n_classes = y_train [y_train .columns [0 ]].nunique ()
101
101
else :
102
- params .n_classes = len (np .unique (y_train ))
102
+ unique_y_train = np .unique (y_train )
103
+ params .n_classes = len (unique_y_train )
104
+ if max (unique_y_train ) != len (unique_y_train ) - 1 :
105
+ params .n_classes = int (max (unique_y_train )) + 1
106
+
103
107
if params .n_classes > 2 :
104
108
lgbm_params ['num_class' ] = params .n_classes
105
109
Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def convert_xgb_predictions(y_pred, objective):
30
30
if objective == 'multi:softprob' :
31
31
y_pred = convert_probs_to_classes (y_pred )
32
32
elif objective == 'binary:logistic' :
33
- y_pred = y_pred .astype (np .int32 )
33
+ y_pred = ( y_pred >= 0.5 ) .astype (np .int32 )
34
34
return y_pred
35
35
36
36
You can’t perform that action at this time.
0 commit comments