# Copyright (C) 2017-2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

# Benchmark script: times scikit-learn's DBSCAN.fit and DBSCAN.fit_predict on
# a point set loaded from a .npy file and prints one result row per timed
# function using the helpers imported from the local `bench` module.

import argparse
from bench import parse_args, time_mean_min, print_header, print_row, size_str
import numpy as np
from sklearn.cluster import DBSCAN

parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-x', '--filex', '--fileX', '--input', required=True,
                    type=str, help='Points to cluster')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=0.5,
                    help='Radius of neighborhood of a point')
# NOTE(review): this option is still accepted so existing command lines keep
# working, but the script fits on the data exactly as loaded. The previous
# code built a stacked copy of the input from it and never used the result.
parser.add_argument('-m', '--data-multiplier', default=100,
                    type=int, help='Data multiplier')
parser.add_argument('-M', '--min-samples', default=5, type=int,
                    help='The minimum number of samples required in a '
                         'neighborhood to consider a point a core point')
params = parse_args(parser, loop_types=('fit', 'predict'),
                    n_jobs_supported=True)

# Load generated data
X = np.load(params.filex)

# Create our clustering object
# N.B. algorithm='auto' will select DAAL's brute force method when running
# daal4py-patched scikit-learn.
dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
                min_samples=params.min_samples, metric='euclidean',
                algorithm='auto')

columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
           'n_clusters', 'time')
params.size = size_str(X.shape)
params.dtype = X.dtype
print_header(columns, params)

# Time fit
fit_time, _ = time_mean_min(dbscan.fit, X,
                            outer_loops=params.fit_outer_loops,
                            inner_loops=params.fit_inner_loops,
                            goal_outer_loops=params.fit_goal,
                            time_limit=params.fit_time_limit,
                            verbose=params.verbose)

# Count clusters from the fitted labels. scikit-learn labels noise points -1,
# so they are excluded. core_sample_indices_ lists core *samples* and its
# length is not the number of clusters.
labels = dbscan.labels_
params.n_clusters = int(np.unique(labels[labels != -1]).size)
print_row(columns, params, function='DBSCAN.fit', time=fit_time)

# Time predict
# DBSCAN has no separate predict step, so the 'predict' loop settings are
# used to time fit_predict.
predict_time, _ = time_mean_min(dbscan.fit_predict, X,
                                outer_loops=params.predict_outer_loops,
                                inner_loops=params.predict_inner_loops,
                                goal_outer_loops=params.predict_goal,
                                time_limit=params.predict_time_limit,
                                verbose=params.verbose)
print_row(columns, params, function='DBSCAN.fit_predict', time=predict_time)
0 commit comments