Skip to content

Commit bdcedf1

Browse files
committed
Repair patching
1 parent 80bb7c1 commit bdcedf1

File tree

4 files changed

+48
-0
lines changed

4 files changed

+48
-0
lines changed

cuml/bench.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
import sklearn
99
import timeit
1010
import json
11+
import os
12+
import sys
13+
14+
15+
if 'FORCE_DAAL4PY_SKLEARN' in os.environ.keys():
16+
try:
17+
from daal4py.sklearn import patch_sklearn
18+
if os.environ['FORCE_DAAL4PY_SKLEARN'] == 'YES':
19+
patch_sklearn()
20+
except ImportError:
21+
print('Failed to import daal4py.sklearn.patch_sklearn '
22+
'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr)
1123

1224

1325
def get_dtype(data):

daal4py/bench.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
import sklearn
99
import timeit
1010
import json
11+
import os
12+
import sys
13+
14+
15+
if 'FORCE_DAAL4PY_SKLEARN' in os.environ.keys():
16+
try:
17+
from daal4py.sklearn import patch_sklearn
18+
if os.environ['FORCE_DAAL4PY_SKLEARN'] == 'YES':
19+
patch_sklearn()
20+
except ImportError:
21+
print('Failed to import daal4py.sklearn.patch_sklearn '
22+
'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr)
1123

1224

1325
def get_dtype(data):

sklearn/bench.py

100644100755
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
import sklearn
99
import timeit
1010
import json
11+
import os
12+
import sys
13+
14+
15+
if 'FORCE_DAAL4PY_SKLEARN' in os.environ.keys():
16+
try:
17+
from daal4py.sklearn import patch_sklearn
18+
if os.environ['FORCE_DAAL4PY_SKLEARN'] == 'YES':
19+
patch_sklearn()
20+
except ImportError:
21+
print('Failed to import daal4py.sklearn.patch_sklearn '
22+
'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr)
1123

1224

1325
def get_dtype(data):

xgboost/bench.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
import sklearn
99
import timeit
1010
import json
11+
import os
12+
import sys
13+
14+
15+
if 'FORCE_DAAL4PY_SKLEARN' in os.environ.keys():
16+
try:
17+
from daal4py.sklearn import patch_sklearn
18+
if os.environ['FORCE_DAAL4PY_SKLEARN'] == 'YES':
19+
patch_sklearn()
20+
except ImportError:
21+
print('Failed to import daal4py.sklearn.patch_sklearn '
22+
'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr)
1123

1224

1325
def get_dtype(data):

0 commit comments

Comments
 (0)