Skip to content

Commit ece7e36

Browse files
authored
Normalizes coefficients in the converter for sparkml DecisionTreeClassifier (#589)
* add pprint Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * update version Signed-off-by: xadupre <[email protected]> * change range to a list Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * debgu Signed-off-by: xadupre <[email protected]> * debug Signed-off-by: xadupre <[email protected]> * fix attribute value Signed-off-by: xadupre <[email protected]> * import code from onnxconverter_common Signed-off-by: xadupre <[email protected]> * flake Signed-off-by: xadupre <[email protected]> * remove spaces Signed-off-by: xadupre <[email protected]> * add a test to investigate categories Signed-off-by: xadupre <[email protected]> * still working Signed-off-by: xadupre <[email protected]> * error message when IN SET rule is detected Signed-off-by: xadupre <[email protected]> * fix rule || Signed-off-by: xadupre <[email protected]> * fix tree ids and unrelated attributes Signed-off-by: xadupre <[email protected]> * make it way faster Signed-off-by: xadupre <[email protected]> * update version Signed-off-by: xadupre <[email protected]> * add debug code Signed-off-by: xadupre <[email protected]> * fix node ids Signed-off-by: xadupre <[email protected]> * update version number Signed-off-by: xadupre <[email protected]> * fix sparkml Signed-off-by: xadupre <[email protected]> * add one unit test Signed-off-by: xadupre <[email protected]> * fix decision tree classifier Signed-off-by: xadupre <[email protected]> * lower opset Signed-off-by: xadupre <[email protected]> * better error message Signed-off-by: xadupre <[email protected]> * remove one attribute Signed-off-by: xadupre 
<[email protected]> * fix Node Signed-off-by: xadupre <[email protected]> * add one unit test Signed-off-by: xadupre <[email protected]> * import Signed-off-by: xadupre <[email protected]> * disable unit test for opset < 17 Signed-off-by: xadupre <[email protected]> * sort by id Signed-off-by: xadupre <[email protected]> * fix index issues Signed-off-by: xadupre <[email protected]> * change version number Signed-off-by: xadupre <[email protected]> * fix random forest classifier for sparkml Signed-off-by: xadupre <[email protected]> * fix numbers Signed-off-by: xadupre <[email protected]> Signed-off-by: xadupre <[email protected]>
1 parent fa7da9a commit ece7e36

14 files changed

+1072
-28
lines changed

onnxmltools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
This framework converts any machine learned model into onnx format
66
which is a common language to describe any machine learned model.
77
"""
8-
__version__ = "1.11.1"
8+
__version__ = "1.11.2"
99
__author__ = "ONNX"
1010
__producer__ = "OnnxMLTools"
1111
__producer_version__ = __version__
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import numpy as np
34
from onnxconverter_common.tree_ensemble import * # noqa
5+
6+
7+
def _process_process_tree_attributes(attrs):
8+
# Spark may store attributes as range and not necessary list.
9+
# ONNX does not support this type of attribute value.
10+
update = {}
11+
wrong_types = []
12+
for k, v in attrs.items():
13+
if isinstance(v, (str, int, float, np.ndarray)):
14+
continue
15+
if isinstance(v, range):
16+
v = update[k] = list(v)
17+
if isinstance(v, list):
18+
if k in ("nodes_values", "nodes_hitrates", "nodes_featureids"):
19+
if any(map(lambda s: not isinstance(s, (float, int)), v)):
20+
v = [x if isinstance(x, (float, int)) else 0 for x in v]
21+
update[k] = v
22+
continue
23+
wrong_types.append(f"Unexpected type {type(v)} for attribute {k!r}.")
24+
if len(wrong_types) > 0:
25+
raise TypeError("Unexpected type for one or several attributes:\n" + "\n".join(wrong_types))
26+
if update:
27+
attrs.update(update)

onnxmltools/convert/sparkml/operator_converters/decision_tree_classifier.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import logging
4+
import numpy as np
35
from ...common.data_types import Int64TensorType, FloatTensorType
4-
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs, \
5-
add_tree_to_attribute_pairs
66
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
77
from ...common._registration import register_converter, register_shape_calculator
8-
from .tree_ensemble_common import save_read_sparkml_model_data, \
9-
sparkml_tree_dataset_to_sklearn
8+
from .tree_ensemble_common import (
9+
save_read_sparkml_model_data, sparkml_tree_dataset_to_sklearn,
10+
add_tree_to_attribute_pairs, get_default_tree_classifier_attribute_pairs)
11+
from .tree_helper import rewrite_ids_and_process
12+
13+
logger = logging.getLogger("onnxmltools")
1014

1115

1216
def convert_decision_tree_classifier(scope, operator, container):
@@ -15,14 +19,22 @@ def convert_decision_tree_classifier(scope, operator, container):
1519

1620
attrs = get_default_tree_classifier_attribute_pairs()
1721
attrs['name'] = scope.get_unique_operator_name(op_type)
18-
attrs["classlabels_int64s"] = range(0, op.numClasses)
22+
attrs["classlabels_int64s"] = list(range(0, op.numClasses))
1923

24+
logger.info("[convert_decision_tree_classifier] save_read_sparkml_model_data")
2025
tree_df = save_read_sparkml_model_data(operator.raw_params['SparkSession'], op)
26+
logger.info("[convert_decision_tree_classifier] sparkml_tree_dataset_to_sklearn")
2127
tree = sparkml_tree_dataset_to_sklearn(tree_df, is_classifier=True)
22-
add_tree_to_attribute_pairs(attrs, True, tree, 0, 1., 0, True)
28+
logger.info("[convert_decision_tree_classifier] add_tree_to_attribute_pairs")
29+
add_tree_to_attribute_pairs(attrs, True, tree, 0, 1., 0, leaf_weights_are_counts=True)
30+
logger.info("[convert_decision_tree_classifier] n_nodes=%d", len(attrs['nodes_nodeids']))
31+
32+
# Some values appear in an array of one element instead of a float.
33+
34+
new_attrs = rewrite_ids_and_process(attrs, logger)
2335

2436
container.add_node(op_type, operator.input_full_names, [operator.outputs[0].full_name,
25-
operator.outputs[1].full_name], op_domain='ai.onnx.ml', **attrs)
37+
operator.outputs[1].full_name], op_domain='ai.onnx.ml', **new_attrs)
2638

2739

2840
register_converter('pyspark.ml.classification.DecisionTreeClassificationModel', convert_decision_tree_classifier)

onnxmltools/convert/sparkml/operator_converters/decision_tree_regressor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
2+
import logging
33
from ...common.data_types import FloatTensorType
44
from ...common.tree_ensemble import add_tree_to_attribute_pairs, \
55
get_default_tree_regressor_attribute_pairs
66
from ...common.utils import check_input_and_output_numbers
77
from ...sparkml.operator_converters.decision_tree_classifier import save_read_sparkml_model_data
88
from ...sparkml.operator_converters.tree_ensemble_common import sparkml_tree_dataset_to_sklearn
99
from ...common._registration import register_converter, register_shape_calculator
10+
from .tree_helper import rewrite_ids_and_process
11+
12+
logger = logging.getLogger("onnxmltools")
1013

1114

1215
def convert_decision_tree_regressor(scope, operator, container):
@@ -20,9 +23,10 @@ def convert_decision_tree_regressor(scope, operator, container):
2023
tree_df = save_read_sparkml_model_data(operator.raw_params['SparkSession'], op)
2124
tree = sparkml_tree_dataset_to_sklearn(tree_df, is_classifier=False)
2225
add_tree_to_attribute_pairs(attrs, False, tree, 0, 1., 0, False)
26+
new_attrs = rewrite_ids_and_process(attrs, logger)
2327

2428
container.add_node(op_type, operator.input_full_names, operator.output_full_names,
25-
op_domain='ai.onnx.ml', **attrs)
29+
op_domain='ai.onnx.ml', **new_attrs)
2630

2731

2832
register_converter('pyspark.ml.regression.DecisionTreeRegressionModel', convert_decision_tree_regressor)

onnxmltools/convert/sparkml/operator_converters/random_forest_classifier.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import logging
34
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs, \
45
add_tree_to_attribute_pairs
56
from ...common._registration import register_converter, register_shape_calculator
67
from .tree_ensemble_common import save_read_sparkml_model_data, sparkml_tree_dataset_to_sklearn
78
from .decision_tree_classifier import calculate_decision_tree_classifier_output_shapes
9+
from .tree_helper import rewrite_ids_and_process
10+
11+
logger = logging.getLogger("onnxmltools")
812

913

1014
def convert_random_forest_classifier(scope, operator, container):
1115
op = operator.raw_operator
1216
op_type = 'TreeEnsembleClassifier'
1317

14-
attr_pairs = get_default_tree_classifier_attribute_pairs()
15-
attr_pairs['name'] = scope.get_unique_operator_name(op_type)
16-
attr_pairs['classlabels_int64s'] = range(0, op.numClasses)
18+
main_attr_pairs = get_default_tree_classifier_attribute_pairs()
19+
main_attr_pairs['name'] = scope.get_unique_operator_name(op_type)
20+
main_attr_pairs['classlabels_int64s'] = list(range(0, op.numClasses))
1721

1822
# random forest calculate the final score by averaging over all trees'
1923
# outcomes, so all trees' weights are identical.
@@ -23,13 +27,21 @@ def convert_random_forest_classifier(scope, operator, container):
2327
tree_model = op.trees[tree_id]
2428
tree_df = save_read_sparkml_model_data(operator.raw_params['SparkSession'], tree_model)
2529
tree = sparkml_tree_dataset_to_sklearn(tree_df, is_classifier=True)
30+
attr_pairs = get_default_tree_classifier_attribute_pairs()
31+
attr_pairs['name'] = scope.get_unique_operator_name(op_type)
32+
attr_pairs['classlabels_int64s'] = list(range(0, op.numClasses))
33+
2634
add_tree_to_attribute_pairs(attr_pairs, True, tree, tree_id,
2735
tree_weight, 0, True)
36+
new_attrs = rewrite_ids_and_process(attr_pairs, logger)
37+
for k, v in new_attrs.items():
38+
if isinstance(v, list) and k not in {'classlabels_int64s'}:
39+
main_attr_pairs[k].extend(v)
2840

2941
container.add_node(
3042
op_type, operator.input_full_names,
3143
[operator.outputs[0].full_name, operator.outputs[1].full_name],
32-
op_domain='ai.onnx.ml', **attr_pairs)
44+
op_domain='ai.onnx.ml', **main_attr_pairs)
3345

3446

3547
register_converter('pyspark.ml.classification.RandomForestClassificationModel', convert_random_forest_classifier)

onnxmltools/convert/sparkml/operator_converters/random_forest_regressor.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import logging
34
from ...common.tree_ensemble import add_tree_to_attribute_pairs, \
45
get_default_tree_regressor_attribute_pairs
56
from ...common._registration import register_converter, register_shape_calculator
67
from .decision_tree_classifier import save_read_sparkml_model_data
78
from .decision_tree_regressor import calculate_decision_tree_regressor_output_shapes
89
from .tree_ensemble_common import sparkml_tree_dataset_to_sklearn
10+
from .tree_helper import rewrite_ids_and_process
11+
12+
logger = logging.getLogger("onnxmltools")
913

1014

1115
def convert_random_forest_regressor(scope, operator, container):
1216
op = operator.raw_operator
1317
op_type = 'TreeEnsembleRegressor'
1418

15-
attrs = get_default_tree_regressor_attribute_pairs()
16-
attrs['name'] = scope.get_unique_operator_name(op_type)
17-
attrs['n_targets'] = 1
19+
main_attrs = get_default_tree_regressor_attribute_pairs()
20+
main_attrs['name'] = scope.get_unique_operator_name(op_type)
21+
main_attrs['n_targets'] = 1
1822

1923
# random forest calculate the final score by averaging over all trees'
2024
# outcomes, so all trees' weights are identical.
@@ -24,11 +28,18 @@ def convert_random_forest_regressor(scope, operator, container):
2428
tree_model = op.trees[tree_id]
2529
tree_df = save_read_sparkml_model_data(operator.raw_params['SparkSession'], tree_model)
2630
tree = sparkml_tree_dataset_to_sklearn(tree_df, is_classifier=False)
31+
attrs = get_default_tree_regressor_attribute_pairs()
32+
attrs['name'] = scope.get_unique_operator_name(op_type)
33+
attrs['n_targets'] = 1
2734
add_tree_to_attribute_pairs(attrs, False, tree, tree_id,
2835
tree_weight, 0, False)
36+
new_attrs = rewrite_ids_and_process(attrs, logger)
37+
for k, v in new_attrs.items():
38+
if isinstance(v, list):
39+
main_attrs[k].extend(v)
2940

3041
container.add_node(op_type, operator.input_full_names, operator.output_full_names[0],
31-
op_domain='ai.onnx.ml', **attrs)
42+
op_domain='ai.onnx.ml', **main_attrs)
3243

3344

3445
register_converter('pyspark.ml.regression.RandomForestRegressionModel', convert_random_forest_regressor)

onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,29 @@ class SparkMLTree(dict):
1313
def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier):
1414
feature = []
1515
threshold = []
16-
tree_pandas = tree_df.toPandas()
16+
tree_pandas = tree_df.toPandas().sort_values("id")
1717
children_left = tree_pandas.leftChild.values.tolist()
1818
children_right = tree_pandas.rightChild.values.tolist()
19-
value = tree_pandas.impurityStats.values.tolist() if is_classifier else tree_pandas.prediction.values.tolist()
20-
split = tree_pandas.split.apply(tuple).values
21-
for item in split:
22-
feature.append(item[0])
23-
threshold.append(item[1][0] if len(item[1]) >= 1 else -1.0)
19+
ids = tree_pandas.id.values.tolist()
20+
if is_classifier:
21+
value = numpy.array(tree_pandas.impurityStats.values.tolist())
22+
else:
23+
value = tree_pandas.prediction.values.tolist()
24+
25+
for item in tree_pandas.split:
26+
if isinstance(item, dict):
27+
try:
28+
feature.append(item["featureIndex"])
29+
threshold.append(item["leftCategoriesOrThreshold"])
30+
except KeyError as e:
31+
raise RuntimeError(f"Unable to process {item}.")
32+
else:
33+
tuple_item = tuple(item)
34+
feature.append(item[0])
35+
threshold.append(item[1][0] if len(item[1]) >= 1 else -1.0)
36+
2437
tree = SparkMLTree()
38+
tree.nodes_ids = ids
2539
tree.children_left = children_left
2640
tree.children_right = children_right
2741
tree.value = numpy.asarray(value, dtype=numpy.float32)
@@ -44,3 +58,105 @@ def save_read_sparkml_model_data(spark: SparkSession, model):
4458
model.write().overwrite().save(path)
4559
df = spark.read.parquet(os.path.join(path, 'data'))
4660
return df
61+
62+
63+
def get_default_tree_classifier_attribute_pairs():
64+
attrs = {}
65+
attrs['post_transform'] = 'NONE'
66+
attrs['nodes_treeids'] = []
67+
attrs['nodes_nodeids'] = []
68+
attrs['nodes_featureids'] = []
69+
attrs['nodes_modes'] = []
70+
attrs['nodes_values'] = []
71+
attrs['nodes_truenodeids'] = []
72+
attrs['nodes_falsenodeids'] = []
73+
attrs['nodes_missing_value_tracks_true'] = []
74+
attrs['nodes_hitrates'] = []
75+
attrs['class_treeids'] = []
76+
attrs['class_nodeids'] = []
77+
attrs['class_ids'] = []
78+
attrs['class_weights'] = []
79+
return attrs
80+
81+
82+
def get_default_tree_regressor_attribute_pairs():
83+
attrs = {}
84+
attrs['post_transform'] = 'NONE'
85+
attrs['n_targets'] = 0
86+
attrs['nodes_treeids'] = []
87+
attrs['nodes_nodeids'] = []
88+
attrs['nodes_featureids'] = []
89+
attrs['nodes_modes'] = []
90+
attrs['nodes_values'] = []
91+
attrs['nodes_truenodeids'] = []
92+
attrs['nodes_falsenodeids'] = []
93+
attrs['nodes_missing_value_tracks_true'] = []
94+
attrs['nodes_hitrates'] = []
95+
attrs['target_treeids'] = []
96+
attrs['target_nodeids'] = []
97+
attrs['target_ids'] = []
98+
attrs['target_weights'] = []
99+
return attrs
100+
101+
102+
def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feature_id, mode, value, true_child_id,
103+
false_child_id, weights, weight_id_bias, leaf_weights_are_counts):
104+
attr_pairs['nodes_treeids'].append(tree_id)
105+
attr_pairs['nodes_nodeids'].append(node_id)
106+
attr_pairs['nodes_featureids'].append(feature_id)
107+
attr_pairs['nodes_modes'].append(mode)
108+
attr_pairs['nodes_values'].append(value)
109+
attr_pairs['nodes_truenodeids'].append(true_child_id)
110+
attr_pairs['nodes_falsenodeids'].append(false_child_id)
111+
attr_pairs['nodes_missing_value_tracks_true'].append(False)
112+
attr_pairs['nodes_hitrates'].append(1.)
113+
114+
# Add leaf information for making prediction
115+
if mode == 'LEAF':
116+
flattened_weights = weights.flatten()
117+
factor = tree_weight
118+
# If the values stored at leaves are counts of possible classes, we need convert them to probabilities by
119+
# doing a normalization.
120+
if leaf_weights_are_counts:
121+
s = sum(flattened_weights)
122+
factor /= float(s) if s != 0. else 1.
123+
flattened_weights = [w * factor for w in flattened_weights]
124+
if len(flattened_weights) == 2 and is_classifier:
125+
flattened_weights = [flattened_weights[1]]
126+
127+
# Note that attribute names for making prediction are different for classifiers and regressors
128+
if is_classifier:
129+
for i, w in enumerate(flattened_weights):
130+
attr_pairs['class_treeids'].append(tree_id)
131+
attr_pairs['class_nodeids'].append(node_id)
132+
attr_pairs['class_ids'].append(i + weight_id_bias)
133+
attr_pairs['class_weights'].append(w)
134+
else:
135+
for i, w in enumerate(flattened_weights):
136+
attr_pairs['target_treeids'].append(tree_id)
137+
attr_pairs['target_nodeids'].append(node_id)
138+
attr_pairs['target_ids'].append(i + weight_id_bias)
139+
attr_pairs['target_weights'].append(w)
140+
141+
142+
def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id, tree_weight,
143+
weight_id_bias, leaf_weights_are_counts):
144+
for i in range(tree.node_count):
145+
node_id = tree.nodes_ids[i]
146+
weight = tree.value[i]
147+
148+
if tree.children_left[i] >= 0 or tree.children_right[i] >= 0:
149+
mode = 'BRANCH_LEQ'
150+
feat_id = tree.feature[i]
151+
threshold = tree.threshold[i]
152+
left_child_id = int(tree.children_left[i])
153+
right_child_id = int(tree.children_right[i])
154+
else:
155+
mode = 'LEAF'
156+
feat_id = 0
157+
threshold = 0.
158+
left_child_id = 0
159+
right_child_id = 0
160+
161+
add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feat_id, mode, threshold,
162+
left_child_id, right_child_id, weight, weight_id_bias, leaf_weights_are_counts)

0 commit comments

Comments
 (0)