Skip to content

Commit 97ecf0c

Browse files
committed
Initial version of scikit-learn DBSCAN benchmark
1 parent d735f55 commit 97ecf0c

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

sklearn/bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def parse_args(parser, size=None, dtypes=None, loop_types=(),
121121

122122
n_jobs = None
123123
if n_jobs_supported and not daal_version:
124-
n_jobs = num_threads = params.num_threads
124+
n_jobs = num_threads = params.threads
125125

126126
# Set threading and DAAL related params here
127127
setattr(params, 'threads', num_threads)

sklearn/dbscan.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
# Copyright (C) 2017-2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

# Benchmark scikit-learn's DBSCAN clustering estimator: times the ``fit``
# and ``fit_predict`` entry points on a user-supplied point set and reports
# results through the shared ``bench`` helpers.

import argparse

import numpy as np
from sklearn.cluster import DBSCAN

from bench import parse_args, print_header, print_row, size_str, time_mean_min

parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-x', '--filex', '--fileX', '--input', required=True,
                    type=str, help='Points to cluster')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=0.5,
                    help='Radius of neighborhood of a point')
parser.add_argument('-m', '--data-multiplier', default=100,
                    type=int, help='Data multiplier')
parser.add_argument('-M', '--min-samples', default=5, type=int,
                    help='The minimum number of samples required in a '
                         'neighborhood to consider a point a core point')
params = parse_args(parser, loop_types=('fit', 'predict'),
                    n_jobs_supported=True)

# Load generated data (a NumPy array saved with np.save).
X = np.load(params.filex)
# NOTE(review): X_mult is built here but never used below, so the
# --data-multiplier option currently has no effect on the benchmark.
# Confirm whether fit_predict was meant to run on X_mult.
X_mult = np.vstack((X,) * params.data_multiplier)

# Create our clustering object.
dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
                min_samples=params.min_samples, metric='euclidean',
                algorithm='auto')

# N.B. algorithm='auto' will select DAAL's brute force method when running
# daal4py-patched scikit-learn.

# Column layout shared by print_header/print_row for the output table.
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
           'n_clusters', 'time')
params.size = size_str(X.shape)
params.dtype = X.dtype
print_header(columns, params)

# Time fit.
fit_time, _ = time_mean_min(dbscan.fit, X,
                            outer_loops=params.fit_outer_loops,
                            inner_loops=params.fit_inner_loops,
                            goal_outer_loops=params.fit_goal,
                            time_limit=params.fit_time_limit,
                            verbose=params.verbose)
# NOTE(review): this reports the number of *core samples*, not the number
# of clusters, under the 'n_clusters' column -- confirm intent (the cluster
# count would be len(set(dbscan.labels_)) minus any noise label).
params.n_clusters = len(dbscan.core_sample_indices_)
print_row(columns, params, function='DBSCAN.fit', time=fit_time)

# Time predict (DBSCAN has no standalone predict; fit_predict re-fits).
predict_time, _ = time_mean_min(dbscan.fit_predict, X,
                                outer_loops=params.predict_outer_loops,
                                inner_loops=params.predict_inner_loops,
                                goal_outer_loops=params.predict_goal,
                                time_limit=params.predict_time_limit,
                                verbose=params.verbose)
print_row(columns, params, function='DBSCAN.fit_predict', time=predict_time)

0 commit comments

Comments
 (0)