Skip to content

Commit e6296c3

Browse files
committed
add knn search
1 parent b1f19c7 commit e6296c3

File tree

3 files changed

+52
-26
lines changed

3 files changed

+52
-26
lines changed

cuml/knn_clsf.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
parser = argparse.ArgumentParser(
1313
description='cuML kNN classifier benchmark')
1414

15+
parser.add_argument('--task', default='classification', type=str,
16+
choices=('search', 'classification'),
17+
help='kNN task: search or classification')
1518
parser.add_argument('--n-neighbors', default=5, type=int,
1619
help='Number of neighbors to use')
1720
parser.add_argument('--weights', type=str, default='uniform',
@@ -33,21 +36,31 @@
3336
metric=params.metric)
3437

3538
# Measure time and accuracy on fitting
36-
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train,
37-
params=params)
38-
y_pred = knn_clsf.predict(X_train)
39-
train_acc = 100 * accuracy_score(y_pred, y_train)
39+
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, params=params)
40+
if args.task == 'classification':
41+
y_pred = knn_clsf.predict(X_train)
42+
train_acc = 100 * accuracy_score(y_pred, y_train)
4043

4144
# Measure time and accuracy on prediction
42-
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
43-
test_acc = 100 * accuracy_score(yp, y_test)
45+
if args.task == 'classification':
46+
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
47+
test_acc = 100 * accuracy_score(yp, y_test)
48+
else:
49+
predict_time, _ = measure_function_time(knn_clsf.kneighbors, X_test, params=params)
4450

4551
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
4652
'n_neighbors', 'n_classes', 'time')
4753

48-
print_output(library='cuml', algorithm='knn_classification',
49-
stages=['training', 'prediction'], columns=columns, params=params,
50-
functions=['knn_clsf.fit', 'knn_clsf.predict'],
51-
times=[train_time, predict_time],
52-
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
53-
data=[X_train, X_test], alg_instance=knn_clsf)
54+
if args.task == 'classification':
55+
print_output(library='cuml', algorithm='knn_classification',
56+
stages=['training', 'prediction'], columns=columns, params=params,
57+
functions=['knn_clsf.fit', 'knn_clsf.predict'],
58+
times=[train_time, predict_time],
59+
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
60+
data=[X_train, X_test], alg_instance=knn_clsf)
61+
else:
62+
print_output(library='cuml', algorithm='knn_search',
63+
stages=['training', 'search'], columns=columns, params=params,
64+
functions=['knn_clsf.fit', 'knn_clsf.kneighbors'],
65+
times=[train_time, predict_time],
66+
data=[X_train, X_test], alg_instance=knn_clsf)

runner.py

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ class GenerationArgs:
288288
if args.output_format == 'json':
289289
try:
290290
json_result['results'].extend(json.loads(stdout))
291-
except json.JSONDecodeError:
292-
pass
291+
except json.JSONDecodeError as decoding_exception:
292+
stderr += str(decoding_exception) + '\n'
293293
elif args.output_format == 'csv':
294294
csv_result += stdout + '\n'
295295
if stderr != '':

sklearn/knn_clsf.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
parser = argparse.ArgumentParser(
1313
description='scikit-learn kNN classifier benchmark')
1414

15+
parser.add_argument('--task', default='classification', type=str,
16+
choices=('search', 'classification'),
17+
help='kNN task: search or classification')
1518
parser.add_argument('--n-neighbors', default=5, type=int,
1619
help='Number of neighbors to use')
1720
parser.add_argument('--weights', type=str, default='uniform',
@@ -34,21 +37,31 @@
3437
metric=params.metric)
3538

3639
# Measure time and accuracy on fitting
37-
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train,
38-
params=params)
39-
y_pred = knn_clsf.predict(X_train)
40-
train_acc = 100 * accuracy_score(y_pred, y_train)
40+
train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, params=params)
41+
if args.task == 'classification':
42+
y_pred = knn_clsf.predict(X_train)
43+
train_acc = 100 * accuracy_score(y_pred, y_train)
4144

4245
# Measure time and accuracy on prediction
43-
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
44-
test_acc = 100 * accuracy_score(yp, y_test)
46+
if args.task == 'classification':
47+
predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params)
48+
test_acc = 100 * accuracy_score(yp, y_test)
49+
else:
50+
predict_time, _ = measure_function_time(knn_clsf.kneighbors, X_test, params=params)
4551

4652
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
4753
'n_neighbors', 'n_classes', 'time')
4854

49-
print_output(library='sklearn', algorithm='knn_classification',
50-
stages=['training', 'prediction'], columns=columns, params=params,
51-
functions=['knn_clsf.fit', 'knn_clsf.predict'],
52-
times=[train_time, predict_time],
53-
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
54-
data=[X_train, X_test], alg_instance=knn_clsf)
55+
if args.task == 'classification':
56+
print_output(library='sklearn', algorithm='knn_classification',
57+
stages=['training', 'prediction'], columns=columns, params=params,
58+
functions=['knn_clsf.fit', 'knn_clsf.predict'],
59+
times=[train_time, predict_time],
60+
accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]',
61+
data=[X_train, X_test], alg_instance=knn_clsf)
62+
else:
63+
print_output(library='sklearn', algorithm='knn_search',
64+
stages=['training', 'search'], columns=columns, params=params,
65+
functions=['knn_clsf.fit', 'knn_clsf.kneighbors'],
66+
times=[train_time, predict_time],
67+
data=[X_train, X_test], alg_instance=knn_clsf)

0 commit comments

Comments
 (0)