Skip to content

Commit 3d91028

Browse files
Merge pull request #42 from Alexsandruss/dev/fix-daal-version
Remove daal_version argument from benchmarks
2 parents ba0250c + 015c90d commit 3d91028

File tree

5 files changed

+39
-57
lines changed

5 files changed

+39
-57
lines changed

cuml/bench.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(),
184184
sklearn_disable_finiteness_check()
185185

186186
# Ask DAAL what it thinks about this number of threads
187-
num_threads, daal_version = prepare_daal(num_threads=params.threads)
188-
if params.verbose and daal_version:
189-
print(f'@ Found DAAL version {daal_version}')
187+
num_threads = prepare_daal_threads(num_threads=params.threads)
188+
if params.verbose:
190189
print(f'@ DAAL gave us {num_threads} threads')
191190

192191
n_jobs = None
193-
if n_jobs_supported and not daal_version:
192+
if n_jobs_supported:
194193
n_jobs = num_threads = params.threads
195194

196195
# Set threading and DAAL related params here
197196
setattr(params, 'threads', num_threads)
198-
setattr(params, 'daal_version', daal_version)
199-
setattr(params, 'using_daal', daal_version is not None)
200197
setattr(params, 'n_jobs', n_jobs)
201198

202199
# Set size string parameter for easy printing
@@ -243,18 +240,16 @@ def set_daal_num_threads(num_threads):
243240
'is being ignored')
244241

245242

246-
def prepare_daal(num_threads=-1):
243+
def prepare_daal_threads(num_threads=-1):
247244
try:
248245
if num_threads > 0:
249246
set_daal_num_threads(num_threads)
250247
import daal4py
251248
num_threads = daal4py.num_threads()
252-
daal_version = daal4py._get__daal_run_version__()
253249
except ImportError:
254250
num_threads = 1
255-
daal_version = None
256251

257-
return num_threads, daal_version
252+
return num_threads
258253

259254

260255
def measure_function_time(func, *args, params, **kwargs):

daal4py/bench.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(),
184184
sklearn_disable_finiteness_check()
185185

186186
# Ask DAAL what it thinks about this number of threads
187-
num_threads, daal_version = prepare_daal(num_threads=params.threads)
188-
if params.verbose and daal_version:
189-
print(f'@ Found DAAL version {daal_version}')
187+
num_threads = prepare_daal_threads(num_threads=params.threads)
188+
if params.verbose:
190189
print(f'@ DAAL gave us {num_threads} threads')
191190

192191
n_jobs = None
193-
if n_jobs_supported and not daal_version:
192+
if n_jobs_supported:
194193
n_jobs = num_threads = params.threads
195194

196195
# Set threading and DAAL related params here
197196
setattr(params, 'threads', num_threads)
198-
setattr(params, 'daal_version', daal_version)
199-
setattr(params, 'using_daal', daal_version is not None)
200197
setattr(params, 'n_jobs', n_jobs)
201198

202199
# Set size string parameter for easy printing
@@ -243,18 +240,16 @@ def set_daal_num_threads(num_threads):
243240
'is being ignored')
244241

245242

246-
def prepare_daal(num_threads=-1):
243+
def prepare_daal_threads(num_threads=-1):
247244
try:
248245
if num_threads > 0:
249246
set_daal_num_threads(num_threads)
250247
import daal4py
251248
num_threads = daal4py.num_threads()
252-
daal_version = daal4py._get__daal_run_version__()
253249
except ImportError:
254250
num_threads = 1
255-
daal_version = None
256251

257-
return num_threads, daal_version
252+
return num_threads
258253

259254

260255
def measure_function_time(func, *args, params, **kwargs):

modelbuilders/bench.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88
import sklearn
99
import timeit
1010
import json
11+
import os
12+
import sys
13+
14+
15+
if os.environ.get('FORCE_DAAL4PY_SKLEARN', False) in ['y', 'yes', 'Y', 'YES', 'Yes']:
16+
try:
17+
from daal4py.sklearn import patch_sklearn
18+
patch_sklearn()
19+
except ImportError:
20+
print('Failed to import daal4py.sklearn.patch_sklearn '
21+
'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr)
1122

1223

1324
def get_dtype(data):
@@ -173,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(),
173184
sklearn_disable_finiteness_check()
174185

175186
# Ask DAAL what it thinks about this number of threads
176-
num_threads, daal_version = prepare_daal(num_threads=params.threads)
177-
if params.verbose and daal_version:
178-
print(f'@ Found DAAL version {daal_version}')
187+
num_threads = prepare_daal_threads(num_threads=params.threads)
188+
if params.verbose:
179189
print(f'@ DAAL gave us {num_threads} threads')
180190

181191
n_jobs = None
182-
if n_jobs_supported and not daal_version:
192+
if n_jobs_supported:
183193
n_jobs = num_threads = params.threads
184194

185195
# Set threading and DAAL related params here
186196
setattr(params, 'threads', num_threads)
187-
setattr(params, 'daal_version', daal_version)
188-
setattr(params, 'using_daal', daal_version is not None)
189197
setattr(params, 'n_jobs', n_jobs)
190198

191199
# Set size string parameter for easy printing
@@ -232,18 +240,16 @@ def set_daal_num_threads(num_threads):
232240
'is being ignored')
233241

234242

235-
def prepare_daal(num_threads=-1):
243+
def prepare_daal_threads(num_threads=-1):
236244
try:
237245
if num_threads > 0:
238246
set_daal_num_threads(num_threads)
239247
import daal4py
240248
num_threads = daal4py.num_threads()
241-
daal_version = daal4py._get__daal_run_version__()
242249
except ImportError:
243250
num_threads = 1
244-
daal_version = None
245251

246-
return num_threads, daal_version
252+
return num_threads
247253

248254

249255
def measure_function_time(func, *args, params, **kwargs):
@@ -508,15 +514,11 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
508514
params.data_order, params.data_format)
509515
# convert existing labels from 1- to 2-dimensional
510516
# if it's forced and possible
511-
if full_data[element] is not None and 'y' in element and label_2d and hasattr(
512-
full_data[element],
513-
'reshape'):
517+
if full_data[element] is not None and 'y' in element and label_2d and hasattr(full_data[element], 'reshape'):
514518
full_data[element] = full_data[element].reshape(
515519
(full_data[element].shape[0], 1))
516520
# add dtype property to data if it's needed and doesn't exist
517-
if full_data[element] is not None and add_dtype and not hasattr(
518-
full_data[element],
519-
'dtype'):
521+
if full_data[element] is not None and add_dtype and not hasattr(full_data[element], 'dtype'):
520522
if hasattr(full_data[element], 'values'):
521523
full_data[element].dtype = full_data[element].values.dtype
522524
elif hasattr(full_data[element], 'dtypes'):
@@ -608,6 +610,6 @@ def print_output(library, algorithm, stages, columns, params, functions,
608610
def import_fptype_getter():
609611
try:
610612
from daal4py.sklearn._utils import getFPType
611-
except ImportError:
613+
except:
612614
from daal4py.sklearn.utils import getFPType
613615
return getFPType

sklearn/bench.py

100755100644
Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(),
184184
sklearn_disable_finiteness_check()
185185

186186
# Ask DAAL what it thinks about this number of threads
187-
num_threads, daal_version = prepare_daal(num_threads=params.threads)
188-
if params.verbose and daal_version:
189-
print(f'@ Found DAAL version {daal_version}')
187+
num_threads = prepare_daal_threads(num_threads=params.threads)
188+
if params.verbose:
190189
print(f'@ DAAL gave us {num_threads} threads')
191190

192191
n_jobs = None
193-
if n_jobs_supported and not daal_version:
192+
if n_jobs_supported:
194193
n_jobs = num_threads = params.threads
195194

196195
# Set threading and DAAL related params here
197196
setattr(params, 'threads', num_threads)
198-
setattr(params, 'daal_version', daal_version)
199-
setattr(params, 'using_daal', daal_version is not None)
200197
setattr(params, 'n_jobs', n_jobs)
201198

202199
# Set size string parameter for easy printing
@@ -243,18 +240,16 @@ def set_daal_num_threads(num_threads):
243240
'is being ignored')
244241

245242

246-
def prepare_daal(num_threads=-1):
243+
def prepare_daal_threads(num_threads=-1):
247244
try:
248245
if num_threads > 0:
249246
set_daal_num_threads(num_threads)
250247
import daal4py
251248
num_threads = daal4py.num_threads()
252-
daal_version = daal4py._get__daal_run_version__()
253249
except ImportError:
254250
num_threads = 1
255-
daal_version = None
256251

257-
return num_threads, daal_version
252+
return num_threads
258253

259254

260255
def measure_function_time(func, *args, params, **kwargs):

xgboost/bench.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(),
184184
sklearn_disable_finiteness_check()
185185

186186
# Ask DAAL what it thinks about this number of threads
187-
num_threads, daal_version = prepare_daal(num_threads=params.threads)
188-
if params.verbose and daal_version:
189-
print(f'@ Found DAAL version {daal_version}')
187+
num_threads = prepare_daal_threads(num_threads=params.threads)
188+
if params.verbose:
190189
print(f'@ DAAL gave us {num_threads} threads')
191190

192191
n_jobs = None
193-
if n_jobs_supported and not daal_version:
192+
if n_jobs_supported:
194193
n_jobs = num_threads = params.threads
195194

196195
# Set threading and DAAL related params here
197196
setattr(params, 'threads', num_threads)
198-
setattr(params, 'daal_version', daal_version)
199-
setattr(params, 'using_daal', daal_version is not None)
200197
setattr(params, 'n_jobs', n_jobs)
201198

202199
# Set size string parameter for easy printing
@@ -243,18 +240,16 @@ def set_daal_num_threads(num_threads):
243240
'is being ignored')
244241

245242

246-
def prepare_daal(num_threads=-1):
243+
def prepare_daal_threads(num_threads=-1):
247244
try:
248245
if num_threads > 0:
249246
set_daal_num_threads(num_threads)
250247
import daal4py
251248
num_threads = daal4py.num_threads()
252-
daal_version = daal4py._get__daal_run_version__()
253249
except ImportError:
254250
num_threads = 1
255-
daal_version = None
256251

257-
return num_threads, daal_version
252+
return num_threads
258253

259254

260255
def measure_function_time(func, *args, params, **kwargs):

0 commit comments

Comments
 (0)