Skip to content

Commit 1baf3fd

Browse files
committed
GetFPType fix
1 parent 4dcd1ce commit 1baf3fd

File tree

14 files changed

+120
-70
lines changed

14 files changed

+120
-70
lines changed

cuml/bench.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@
1010
import json
1111

1212

13+
def get_dtype(data):
14+
'''
15+
Get type of input data as numpy.dtype
16+
'''
17+
if hasattr(data, 'dtype'):
18+
return data.dtype
19+
elif hasattr(data, 'dtypes'):
20+
return str(data.dtypes[0])
21+
elif hasattr(data, 'values'):
22+
return data.values.dtype
23+
else:
24+
raise ValueError(f'Impossible to get data type of {type(data)}')
25+
26+
27+
try:
28+
from daal4py.sklearn._utils import getFPType
29+
except ImportError:
30+
def getFPType(X):
31+
dtype = str(get_dtype(X))
32+
if 'float32' in dtype:
33+
return 'float'
34+
elif 'float64' in dtype:
35+
return 'double'
36+
else:
37+
ValueError('Unknown type')
38+
39+
1340
def sklearn_disable_finiteness_check():
1441
try:
1542
sklearn.set_config(assume_finite=True)
@@ -427,18 +454,6 @@ def convert_data(data, dtype, data_order, data_format):
427454
return cudf.DataFrame.from_pandas(pd.DataFrame(data))
428455

429456

430-
def get_dtype(data):
431-
'''
432-
Get type of input data as numpy.dtype
433-
'''
434-
if hasattr(data, 'dtype'):
435-
return data.dtype
436-
elif hasattr(data, 'dtypes'):
437-
return str(data.dtypes[0])
438-
elif hasattr(data, 'values'):
439-
return data.values.dtype
440-
441-
442457
def read_csv(filename, params):
443458
from string import ascii_lowercase, ascii_uppercase
444459

daal4py/bench.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@
1010
import json
1111

1212

13+
def get_dtype(data):
14+
'''
15+
Get type of input data as numpy.dtype
16+
'''
17+
if hasattr(data, 'dtype'):
18+
return data.dtype
19+
elif hasattr(data, 'dtypes'):
20+
return str(data.dtypes[0])
21+
elif hasattr(data, 'values'):
22+
return data.values.dtype
23+
else:
24+
raise ValueError(f'Impossible to get data type of {type(data)}')
25+
26+
27+
try:
28+
from daal4py.sklearn._utils import getFPType
29+
except ImportError:
30+
def getFPType(X):
31+
dtype = str(get_dtype(X))
32+
if 'float32' in dtype:
33+
return 'float'
34+
elif 'float64' in dtype:
35+
return 'double'
36+
else:
37+
ValueError('Unknown type')
38+
39+
1340
def sklearn_disable_finiteness_check():
1441
try:
1542
sklearn.set_config(assume_finite=True)
@@ -427,18 +454,6 @@ def convert_data(data, dtype, data_order, data_format):
427454
return cudf.DataFrame.from_pandas(pd.DataFrame(data))
428455

429456

430-
def get_dtype(data):
431-
'''
432-
Get type of input data as numpy.dtype
433-
'''
434-
if hasattr(data, 'dtype'):
435-
return data.dtype
436-
elif hasattr(data, 'dtypes'):
437-
return str(data.dtypes[0])
438-
elif hasattr(data, 'values'):
439-
return data.values.dtype
440-
441-
442457
def read_csv(filename, params):
443458
from string import ascii_lowercase, ascii_uppercase
444459

daal4py/dbscan.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44

55
import argparse
66
from bench import (
7-
parse_args, measure_function_time, load_data, print_output,
8-
import_fptype_getter
7+
parse_args, measure_function_time, load_data, print_output, getFPType
98
)
109
from daal4py import dbscan
11-
getFPType = import_fptype_getter()
1210

1311

1412
parser = argparse.ArgumentParser(description='daal4py DBSCAN clustering '

daal4py/df_clsf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import argparse
66
from bench import (
77
parse_args, measure_function_time, load_data, print_output, accuracy_score,
8-
float_or_int, import_fptype_getter
8+
float_or_int, getFPType
99
)
1010
import numpy as np
1111
from daal4py import (
1212
decision_forest_classification_training,
1313
decision_forest_classification_prediction, engines_mt2203
1414
)
15-
getFPType = import_fptype_getter()
1615

1716

1817
def df_clsf_fit(X, y, n_classes, n_trees=100, seed=12345,

daal4py/df_regr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
import argparse
77
from bench import (
88
parse_args, measure_function_time, load_data, print_output, rmse_score,
9-
float_or_int, import_fptype_getter
9+
float_or_int, getFPType
1010
)
1111
from daal4py import (
1212
decision_forest_regression_training,
1313
decision_forest_regression_prediction,
1414
engines_mt2203
1515
)
16-
getFPType = import_fptype_getter()
1716

1817

1918
def df_regr_fit(X, y, n_trees=100, seed=12345, n_features_per_node=0,

daal4py/distances.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import argparse
66
from bench import (
77
parse_args, measure_function_time, print_output, load_data,
8-
import_fptype_getter
8+
getFPType
99
)
1010
import daal4py
11-
getFPType = import_fptype_getter()
1211

1312

1413
def compute_distances(pairwise_distances, X):

daal4py/kmeans.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44

55
import argparse
66
from bench import (
7-
parse_args, measure_function_time, load_data, print_output,
8-
import_fptype_getter
7+
parse_args, measure_function_time, load_data, print_output, getFPType
98
)
109
import numpy as np
1110
from daal4py import kmeans
12-
getFPType = import_fptype_getter()
1311

1412

1513
parser = argparse.ArgumentParser(description='daal4py K-Means clustering '

daal4py/linear.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import argparse
66
from bench import (
77
parse_args, measure_function_time, load_data, print_output, rmse_score,
8-
import_fptype_getter
8+
getFPType
99
)
1010
from daal4py import linear_regression_training, linear_regression_prediction
11-
getFPType = import_fptype_getter()
1211

1312

1413
parser = argparse.ArgumentParser(description='daal4py linear regression '

daal4py/log_reg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import argparse
66
from bench import (
77
parse_args, measure_function_time, load_data, print_output, accuracy_score,
8-
import_fptype_getter
8+
getFPType
99
)
1010
import numpy as np
1111
import daal4py
1212
from daal4py import math_logistic, math_softmax
1313
from daal4py.sklearn.utils import make2d
1414
import scipy.optimize
15-
getFPType = import_fptype_getter()
15+
1616

1717
_logistic_loss = daal4py.optimization_solver_logistic_loss
1818
_cross_entropy_loss = daal4py.optimization_solver_cross_entropy_loss

daal4py/pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import argparse
66
from bench import (
77
parse_args, measure_function_time, load_data, print_output,
8-
import_fptype_getter
8+
getFPType
99
)
1010
import numpy as np
1111
from daal4py import pca, pca_transform, normalization_zscore
1212
from sklearn.utils.extmath import svd_flip
13-
getFPType = import_fptype_getter()
13+
1414

1515
parser = argparse.ArgumentParser(description='daal4py PCA benchmark')
1616
parser.add_argument('--svd-solver', type=str,

0 commit comments

Comments
 (0)