Skip to content

Commit 0bbcbee

Browse files
committed
Add daal4py dbscan benchmark
1 parent 22949e8 commit 0bbcbee

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

daal4py/dbscan.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (C) 2017-2019 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
import argparse
6+
from bench import parse_args, time_mean_min, print_header, print_row, size_str
7+
from daal4py import dbscan
8+
from daal4py.sklearn.utils import getFPType
9+
import numpy as np
10+
11+
parser = argparse.ArgumentParser(description='daal4py DBSCAN clustering '
12+
'benchmark')
13+
parser.add_argument('-x', '--filex', '--fileX', '--input', required=True,
14+
type=str, help='Points to cluster')
15+
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10,
16+
help='Radius of neighborhood of a point')
17+
parser.add_argument('-m', '--data-multiplier', default=100,
18+
type=int, help='Data multiplier')
19+
parser.add_argument('-M', '--min-samples', default=5, type=int,
20+
help='The minimum number of samples required in a '
21+
'neighborhood to consider a point a core point')
22+
params = parse_args(parser, prefix='daal4py')
23+
24+
# Load generated data
25+
X = np.load(params.filex)
26+
X_mult = np.vstack((X,) * params.data_multiplier)
27+
28+
params.size = size_str(X.shape)
29+
params.dtype = X.dtype
30+
31+
32+
# Define functions to time
33+
def test_dbscan(X):
34+
algorithm = dbscan(
35+
fptype=getFPType(X),
36+
epsilon=params.eps,
37+
minObservations=params.min_samples,
38+
resultsToCompute='computeCoreIndices'
39+
)
40+
return algorithm.compute(X)
41+
42+
43+
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
44+
'n_clusters', 'time')
45+
print_header(columns, params)
46+
47+
# Time clustering
48+
time, result = time_mean_min(test_dbscan, X,
49+
outer_loops=params.outer_loops,
50+
inner_loops=params.inner_loops,
51+
goal_outer_loops=params.goal,
52+
time_limit=params.time_limit,
53+
verbose=params.verbose)
54+
params.n_clusters = result.nClusters[0, 0]
55+
print_row(columns, params, function='DBSCAN', time=time)

0 commit comments

Comments
 (0)