|
12 | 12 | parser = argparse.ArgumentParser(
|
13 | 13 | description='scikit-learn kNN classifier benchmark')
|
14 | 14 |
|
| 15 | +parser.add_argument('--task', default='classification', type=str, |
| 16 | + choices=('search', 'classification'), |
| 17 | + help='kNN task: search or classification') |
15 | 18 | parser.add_argument('--n-neighbors', default=5, type=int,
|
16 | 19 | help='Number of neighbors to use')
|
17 | 20 | parser.add_argument('--weights', type=str, default='uniform',
|
|
34 | 37 | metric=params.metric)
|
35 | 38 |
|
36 | 39 | # Measure time and accuracy on fitting
|
37 |
| -train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, |
38 |
| - params=params) |
39 |
| -y_pred = knn_clsf.predict(X_train) |
40 |
| -train_acc = 100 * accuracy_score(y_pred, y_train) |
| 40 | +train_time, _ = measure_function_time(knn_clsf.fit, X_train, y_train, params=params) |
| 41 | +if args.task == 'classification': |
| 42 | + y_pred = knn_clsf.predict(X_train) |
| 43 | + train_acc = 100 * accuracy_score(y_pred, y_train) |
41 | 44 |
|
42 | 45 | # Measure time and accuracy on prediction
|
43 |
| -predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params) |
44 |
| -test_acc = 100 * accuracy_score(yp, y_test) |
| 46 | +if args.task == 'classification': |
| 47 | + predict_time, yp = measure_function_time(knn_clsf.predict, X_test, params=params) |
| 48 | + test_acc = 100 * accuracy_score(yp, y_test) |
| 49 | +else: |
| 50 | + predict_time, _ = measure_function_time(knn_clsf.kneighbors, X_test, params=params) |
45 | 51 |
|
46 | 52 | columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
|
47 | 53 | 'n_neighbors', 'n_classes', 'time')
|
48 | 54 |
|
49 |
| -print_output(library='sklearn', algorithm='knn_classification', |
50 |
| - stages=['training', 'prediction'], columns=columns, params=params, |
51 |
| - functions=['knn_clsf.fit', 'knn_clsf.predict'], |
52 |
| - times=[train_time, predict_time], |
53 |
| - accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]', |
54 |
| - data=[X_train, X_test], alg_instance=knn_clsf) |
| 55 | +if args.task == 'classification': |
| 56 | + print_output(library='sklearn', algorithm='knn_classification', |
| 57 | + stages=['training', 'prediction'], columns=columns, params=params, |
| 58 | + functions=['knn_clsf.fit', 'knn_clsf.predict'], |
| 59 | + times=[train_time, predict_time], |
| 60 | + accuracies=[train_acc, test_acc], accuracy_type='accuracy[%]', |
| 61 | + data=[X_train, X_test], alg_instance=knn_clsf) |
| 62 | +else: |
| 63 | + print_output(library='sklearn', algorithm='knn_search', |
| 64 | + stages=['training', 'search'], columns=columns, params=params, |
| 65 | + functions=['knn_clsf.fit', 'knn_clsf.kneighbors'], |
| 66 | + times=[train_time, predict_time], |
| 67 | + data=[X_train, X_test], alg_instance=knn_clsf) |
0 commit comments