|
| 1 | +# Copyright (C) 2017-2019 Intel Corporation |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: MIT |
| 4 | + |
| 5 | +import argparse |
| 6 | +from bench import parse_args, time_mean_min, print_header, print_row, size_str |
| 7 | +from daal4py import dbscan |
| 8 | +from daal4py.sklearn.utils import getFPType |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +parser = argparse.ArgumentParser(description='daal4py DBSCAN clustering ' |
| 12 | + 'benchmark') |
| 13 | +parser.add_argument('-x', '--filex', '--fileX', '--input', required=True, |
| 14 | + type=str, help='Points to cluster') |
| 15 | +parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10, |
| 16 | + help='Radius of neighborhood of a point') |
| 17 | +parser.add_argument('-m', '--data-multiplier', default=100, |
| 18 | + type=int, help='Data multiplier') |
| 19 | +parser.add_argument('-M', '--min-samples', default=5, type=int, |
| 20 | + help='The minimum number of samples required in a ' |
| 21 | + 'neighborhood to consider a point a core point') |
| 22 | +params = parse_args(parser, prefix='daal4py') |
| 23 | + |
| 24 | +# Load generated data |
| 25 | +X = np.load(params.filex) |
| 26 | +X_mult = np.vstack((X,) * params.data_multiplier) |
| 27 | + |
| 28 | +params.size = size_str(X.shape) |
| 29 | +params.dtype = X.dtype |
| 30 | + |
| 31 | + |
| 32 | +# Define functions to time |
| 33 | +def test_dbscan(X): |
| 34 | + algorithm = dbscan( |
| 35 | + fptype=getFPType(X), |
| 36 | + epsilon=params.eps, |
| 37 | + minObservations=params.min_samples, |
| 38 | + resultsToCompute='computeCoreIndices' |
| 39 | + ) |
| 40 | + return algorithm.compute(X) |
| 41 | + |
| 42 | + |
| 43 | +columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size', |
| 44 | + 'n_clusters', 'time') |
| 45 | +print_header(columns, params) |
| 46 | + |
| 47 | +# Time clustering |
| 48 | +time, result = time_mean_min(test_dbscan, X, |
| 49 | + outer_loops=params.outer_loops, |
| 50 | + inner_loops=params.inner_loops, |
| 51 | + goal_outer_loops=params.goal, |
| 52 | + time_limit=params.time_limit, |
| 53 | + verbose=params.verbose) |
| 54 | +params.n_clusters = result.nClusters[0, 0] |
| 55 | +print_row(columns, params, function='DBSCAN', time=time) |
0 commit comments