Skip to content

Commit b1f19c7

Browse files
committed
kNN classification update
1 parent 002fa9d commit b1f19c7

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

cuml/knn_clsf.py

File mode changed: 100644 → 100755
Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,22 @@
3232
algorithm=params.method,
3333
metric=params.metric)
3434

35-
knn_clsf.fit(X_train, y_train)
36-
# Time predict
37-
time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
35+
# Measure time and accuracy on fitting
36+
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train,
37+
params=params)
38+
y_pred = knn_clsf.predict(X_train)
39+
train_acc = 100 * accuracy_score(y_pred, y_train)
3840

39-
acc = 100 * accuracy_score(yp, y_test)
41+
# Measure time and accuracy on prediction
42+
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
43+
test_acc = 100 * accuracy_score(yp, y_test)
4044

4145
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
4246
'n_neighbors', 'n_classes', 'time')
4347

4448
print_output(library='cuml', algorithm='knn_classification',
45-
stages=['prediction'], columns=columns, params=params,
46-
functions=['knn_clsf.predict'], times=[time], accuracies=[acc],
47-
accuracy_type='accuracy[%]', data=[X_test], alg_instance=knn_clsf)
49+
stages=['training', 'prediction'], columns=columns, params=params,
50+
functions=['knn_clsf.fit', 'knn_clsf.predict'],
51+
times=[train_time, predict_time],
52+
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
53+
data=[X_train, X_test], alg_instance=knn_clsf)

sklearn/knn_clsf.py

File mode changed: 100644 → 100755
Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
parser.add_argument('--weights', type=str, default='uniform',
1818
help='Weight function used in prediction')
1919
parser.add_argument('--method', type=str, default='brute',
20+
choices=('brute', 'kd_tree', 'ball_tree', 'auto'),
2021
help='Algorithm used to compute the nearest neighbors')
2122
parser.add_argument('--metric', type=str, default='euclidean',
2223
help='Distance metric to use')
@@ -32,16 +33,22 @@
3233
algorithm=params.method,
3334
metric=params.metric)
3435

35-
knn_clsf.fit(X_train, y_train)
36-
# Time predict
37-
time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
36+
# Measure time and accuracy on fitting
37+
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train,
38+
params=params)
39+
y_pred = knn_clsf.predict(X_train)
40+
train_acc = 100 * accuracy_score(y_pred, y_train)
3841

39-
acc = 100 * accuracy_score(yp, y_test)
42+
# Measure time and accuracy on prediction
43+
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
44+
test_acc = 100 * accuracy_score(yp, y_test)
4045

4146
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
4247
'n_neighbors', 'n_classes', 'time')
4348

4449
print_output(library='sklearn', algorithm='knn_classification',
45-
stages=['prediction'], columns=columns, params=params,
46-
functions=['knn_clsf.predict'], times=[time], accuracies=[acc],
47-
accuracy_type='accuracy[%]', data=[X_test], alg_instance=knn_clsf)
50+
stages=['training', 'prediction'], columns=columns, params=params,
51+
functions=['knn_clsf.fit', 'knn_clsf.predict'],
52+
times=[train_time, predict_time],
53+
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
54+
data=[X_train, X_test], alg_instance=knn_clsf)

0 commit comments

Comments
 (0)