Skip to content

Commit b408ce8

Browse files
authored
Fix parameters for RF in config and fix n_jobs for original scikit-learn (#51)
1 parent 8bbf7d6 commit b408ce8

File tree

6 files changed

+9
-7
lines changed

6 files changed

+9
-7
lines changed

bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_optimal_cache_size(n_rows, dtype=np.double, max_cache=64):
9595

9696

9797
def parse_args(parser, size=None, loop_types=(),
98-
n_jobs_supported=False, prefix='sklearn'):
98+
n_jobs_supported=True, prefix='sklearn'):
9999
'''
100100
Add common arguments useful for most benchmarks and parse.
101101

configs/skl_config.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@
146146
"num-trees": [50],
147147
"max-depth": [16],
148148
"max-leaf-nodes": [131072],
149-
"max-features": [0.2],
150-
"use-sklearn-class": [""]
149+
"max-features": [0.2]
151150
},
152151
{
153152
"algorithm": "ridge",

sklearn_bench/dbscan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
parser.add_argument('-m', '--min-samples', default=5, type=int,
2828
help='The minimum number of samples required in a '
2929
'neighborhood to consider a point a core point')
30-
params = bench.parse_args(parser, n_jobs_supported=True)
30+
params = bench.parse_args(parser)
3131

3232
from sklearn.cluster import DBSCAN
3333

sklearn_bench/df_clsf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@
6060
max_leaf_nodes=params.max_leaf_nodes,
6161
min_impurity_decrease=params.min_impurity_decrease,
6262
bootstrap=params.bootstrap,
63-
random_state=params.seed)
63+
random_state=params.seed,
64+
n_jobs=params.n_jobs)
6465

6566
params.n_classes = len(np.unique(y_train))
6667

sklearn_bench/df_regr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
max_leaf_nodes=params.max_leaf_nodes,
5959
min_impurity_decrease=params.min_impurity_decrease,
6060
bootstrap=params.bootstrap,
61-
random_state=params.seed)
61+
random_state=params.seed,
62+
n_jobs=params.n_jobs)
6263

6364
fit_time, _ = bench.measure_function_time(regr.fit, X_train, y_train, params=params)
6465

sklearn_bench/knn_clsf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
knn_clsf = KNeighborsClassifier(n_neighbors=params.n_neighbors,
5050
weights=params.weights,
5151
algorithm=params.method,
52-
metric=params.metric)
52+
metric=params.metric,
53+
n_jobs=params.n_jobs)
5354

5455
# Measure time and accuracy on fitting
5556
train_time, _ = bench.measure_function_time(knn_clsf.fit, X_train, y_train, params=params)

0 commit comments

Comments
 (0)