Skip to content

Commit 8ee94d9

Browse files
author
igor_rukhovich
committed
Changed output style
1 parent f0aa477 commit 8ee94d9

File tree

3 files changed

+18
-16
lines changed

3 files changed

+18
-16
lines changed

modelbuilders/lgbm_mb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
lgbm.train, lgbm_params, lgbm_train, params=params, num_boost_round=params.n_estimators,
109109
valid_sets=lgbm_train, verbose_eval=False)
110110
train_metric = None
111-
if X_train != X_test:
111+
if not X_train.equals(X_test):
112112
y_train_pred = model_lgbm.predict(X_train)
113113
train_metric = metric_func(y_train, y_train_pred)
114114

@@ -132,7 +132,7 @@
132132

133133
print_output(
134134
library='modelbuilders', algorithm=f'lightgbm_{task}_and_modelbuilder',
135-
stages=['lgbm_train', 'lgbm_predict', 'daal_predict'],
135+
stages=['lgbm_train', 'lgbm_predict', 'daal4py_predict'],
136136
columns=columns, params=params,
137137
functions=['lgbm_dataset', 'lgbm_dataset', 'lgbm_train', 'lgbm_predict', 'lgbm_to_daal',
138138
'daal_compute'],

modelbuilders/utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,25 @@ def print_output(library, algorithm, stages, columns, params, functions,
3737
accuracy=accuracies[i])
3838
elif params.output_format == 'json':
3939
output = []
40+
output.append({
41+
'library': library,
42+
'algorithm': algorithm,
43+
'input_data': {
44+
'data_format': params.data_format,
45+
'data_order': params.data_order,
46+
'data_type': str(params.dtype),
47+
'dataset_name': params.dataset_name,
48+
'rows': data[0].shape[0],
49+
'columns': data[0].shape[1]
50+
}
51+
})
52+
if hasattr(params, 'n_classes'):
53+
output[-1]['input_data'].update({'classes': params.n_classes})
4054
for i in range(len(stages)):
4155
result = {
42-
'library': library,
43-
'algorithm': algorithm,
4456
'stage': stages[i],
45-
'input_data': {
46-
'data_format': params.data_format,
47-
'data_order': params.data_order,
48-
'data_type': str(params.dtype),
49-
'dataset_name': params.dataset_name,
50-
'rows': data[i].shape[0],
51-
'columns': data[i].shape[1]
52-
}
5357
}
54-
if stages[i] == 'daal4py_predict':
58+
if 'daal' in stages[i]:
5559
result.update({'conversion_to_daal4py': times[2 * i],
5660
'prediction_time': times[2 * i + 1]})
5761
elif 'train' in stages[i]:
@@ -62,7 +66,5 @@ def print_output(library, algorithm, stages, columns, params, functions,
6266
'prediction_time': times[2 * i + 1]})
6367
if accuracies[i] is not None:
6468
result.update({f'{accuracy_type}': accuracies[i]})
65-
if hasattr(params, 'n_classes'):
66-
result['input_data'].update({'classes': params.n_classes})
6769
output.append(result)
6870
print(json.dumps(output, indent=4))

modelbuilders/xgb_mb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def predict():
135135
t_train, model_xgb = measure_function_time(
136136
fit, None if params.count_dmatrix else dtrain, params=params)
137137
train_metric = None
138-
if X_train != X_test:
138+
if not X_train.equals(X_test):
139139
y_train_pred = model_xgb.predict(dtrain)
140140
train_metric = metric_func(y_train, y_train_pred)
141141

0 commit comments

Comments
 (0)