Skip to content

Commit 0792964

Browse files
committed
Add 'count-dmatrix' option in XGB benchmark
1 parent c7e0abd commit 0792964

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

xgboost/gbt.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def convert_xgb_predictions(y_pred, objective):
6565
choices=('reg:squarederror', 'binary:logistic',
6666
'multi:softmax', 'multi:softprob'),
6767
help='Control a balance of positive and negative weights')
68+
parser.add_argument('--count-dmatrix', default=False, action='store_true',
69+
help='Count DMatrix creation in time measurements')
6870

6971
params = parse_args(parser)
7072

@@ -122,14 +124,26 @@ def convert_xgb_predictions(y_pred, objective):
122124

123125
dtrain = xgb.DMatrix(X_train, y_train)
124126
dtest = xgb.DMatrix(X_test, y_test)
127+
if params.count_dmatrix:
128+
def fit():
129+
dtrain = xgb.DMatrix(X_train, y_train)
130+
return xgb.train(xgb_params, dtrain, params.n_estimators)
131+
132+
def predict():
133+
dtest = xgb.DMatrix(X_test, y_test)
134+
return booster.predict(dtest)
135+
else:
136+
def fit():
137+
return xgb.train(xgb_params, dtrain, params.n_estimators)
138+
139+
def predict():
140+
return booster.predict(dtest)
125141

126-
fit_time, booster = measure_function_time(
127-
xgb.train, xgb_params, dtrain, params.n_estimators, params=params)
142+
fit_time, booster = measure_function_time(fit, params=params)
128143
y_pred = convert_xgb_predictions(booster.predict(dtrain), params.objective)
129144
train_metric = metric_func(y_pred, y_train)
130145

131-
predict_time, y_pred = measure_function_time(
132-
booster.predict, dtest, params=params)
146+
predict_time, y_pred = measure_function_time(predict, params=params)
133147
test_metric = metric_func(
134148
convert_xgb_predictions(y_pred, params.objective), y_test)
135149

0 commit comments

Comments
 (0)