Skip to content

Commit 80bb7c1

Browse files
committed
Fix namings in knn_clsf
1 parent f44073a commit 80bb7c1

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

cuml/knn_clsf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@
3737

3838
# Measure time and accuracy on fitting
3939
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, params=params)
40-
if args.task == 'classification':
40+
if params.task == 'classification':
4141
y_pred = knn_clsf.predict(X_train)
4242
train_acc = 100 * accuracy_score(y_pred, y_train)
4343

4444
# Measure time and accuracy on prediction
45-
if args.task == 'classification':
45+
if params.task == 'classification':
4646
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
4747
test_acc = 100 * accuracy_score(yp, y_test)
4848
else:
@@ -51,7 +51,7 @@
5151
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
5252
'n_neighbors', 'n_classes', 'time')
5353

54-
if args.task == 'classification':
54+
if params.task == 'classification':
5555
print_output(library='cuml', algorithm='knn_classification',
5656
stages=['training', 'prediction'], columns=columns, params=params,
5757
functions=['knn_clsf.fit', 'knn_clsf.predict'],

sklearn/knn_clsf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838

3939
# Measure time and accuracy on fitting
4040
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, params=params)
41-
if args.task == 'classification':
41+
if params.task == 'classification':
4242
y_pred = knn_clsf.predict(X_train)
4343
train_acc = 100 * accuracy_score(y_pred, y_train)
4444

4545
# Measure time and accuracy on prediction
46-
if args.task == 'classification':
46+
if params.task == 'classification':
4747
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
4848
test_acc = 100 * accuracy_score(yp, y_test)
4949
else:
@@ -52,7 +52,7 @@
5252
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
5353
'n_neighbors', 'n_classes', 'time')
5454

55-
if args.task == 'classification':
55+
if params.task == 'classification':
5656
print_output(library='sklearn', algorithm='knn_classification',
5757
stages=['training', 'prediction'], columns=columns, params=params,
5858
functions=['knn_clsf.fit', 'knn_clsf.predict'],

0 commit comments

Comments
 (0)