Skip to content

Commit bcb9e14

Browse files
committed
fixing modelbuilder benchmark for covtype dataset (#125)
1 parent 1d557d8 commit bcb9e14

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

modelbuilders_bench/lgbm_mb.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@
9999
if 'cudf' in str(type(y_train)):
100100
params.n_classes = y_train[y_train.columns[0]].nunique()
101101
else:
102-
params.n_classes = len(np.unique(y_train))
102+
unique_y_train = np.unique(y_train)
103+
params.n_classes = len(unique_y_train)
104+
if max(unique_y_train) != len(unique_y_train) - 1:
105+
params.n_classes = int(max(unique_y_train)) + 1
106+
103107
if params.n_classes > 2:
104108
lgbm_params['num_class'] = params.n_classes
105109

modelbuilders_bench/xgb_mb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def convert_xgb_predictions(y_pred, objective):
3030
if objective == 'multi:softprob':
3131
y_pred = convert_probs_to_classes(y_pred)
3232
elif objective == 'binary:logistic':
33-
y_pred = y_pred.astype(np.int32)
33+
y_pred = (y_pred >= 0.5).astype(np.int32)
3434
return y_pred
3535

3636

0 commit comments

Comments
 (0)