Skip to content

Commit a7bab75

Browse files
authored
Merge pull request #40 from ShvetsKS/update_with_inplace_predict
In-place predict for higgs
2 parents 8f8d0a4 + 977d3e7 commit a7bab75

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

configs/xgb_cpu_config.json

File mode changed: 100755 → 100644
Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -131,7 +131,8 @@
131131
"n-estimators": [1000],
132132
"objective": ["binary:logistic"],
133133
"tree-method": ["hist"],
134-
"enable-experimental-json-serialization": ["False"]
134+
"enable-experimental-json-serialization": ["False"],
135+
"inplace-predict": [""]
135136
},
136137
{
137138
"algorithm": "gbt",

configs/xgb_gpu_config.json

File mode changed: 100755 → 100644
Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -129,7 +129,8 @@
129129
"max-leaves": [256],
130130
"n-estimators": [1000],
131131
"objective": ["binary:logistic"],
132-
"tree-method": ["gpu_hist"]
132+
"tree-method": ["gpu_hist"],
133+
"inplace-predict": [""]
133134
},
134135
{
135136
"algorithm": "gbt",

xgboost/gbt.py

Lines changed: 10 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -67,6 +67,8 @@ def convert_xgb_predictions(y_pred, objective):
6767
help='Control a balance of positive and negative weights')
6868
parser.add_argument('--count-dmatrix', default=False, action='store_true',
6969
help='Count DMatrix creation in time measurements')
70+
parser.add_argument('--inplace-predict', default=False, action='store_true',
71+
help='Perform inplace_predict instead of default')
7072
parser.add_argument('--single-precision-histogram', default=False, action='store_true',
7173
help='Build histograms instead of double precision')
7274
parser.add_argument('--enable-experimental-json-serialization', default=True,
@@ -135,9 +137,13 @@ def fit():
135137
dtrain = xgb.DMatrix(X_train, y_train)
136138
return xgb.train(xgb_params, dtrain, params.n_estimators)
137139

138-
def predict():
139-
dtest = xgb.DMatrix(X_test, y_test)
140-
return booster.predict(dtest)
140+
if params.inplace_predict == False:
141+
def predict():
142+
dtest = xgb.DMatrix(X_test, y_test)
143+
return booster.predict(dtest)
144+
else:
145+
def predict():
146+
return booster.inplace_predict(np.ascontiguousarray(X_test.values, dtype=np.float32))
141147
else:
142148
def fit():
143149
return xgb.train(xgb_params, dtrain, params.n_estimators)
@@ -150,8 +156,7 @@ def predict():
150156
train_metric = metric_func(y_pred, y_train)
151157

152158
predict_time, y_pred = measure_function_time(predict, params=params)
153-
test_metric = metric_func(
154-
convert_xgb_predictions(y_pred, params.objective), y_test)
159+
test_metric = metric_func(convert_xgb_predictions(y_pred, params.objective), y_test)
155160

156161
print_output(library='xgboost', algorithm=f'gradient_boosted_trees_{task}',
157162
stages=['training', 'prediction'], columns=columns,

0 commit comments

Comments (0)