@@ -13,15 +13,29 @@ class SparkMLTree(dict):
 def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier):
     feature = []
     threshold = []
-    tree_pandas = tree_df.toPandas()
+    tree_pandas = tree_df.toPandas().sort_values("id")
     children_left = tree_pandas.leftChild.values.tolist()
     children_right = tree_pandas.rightChild.values.tolist()
-    value = tree_pandas.impurityStats.values.tolist() if is_classifier else tree_pandas.prediction.values.tolist()
-    split = tree_pandas.split.apply(tuple).values
-    for item in split:
-        feature.append(item[0])
-        threshold.append(item[1][0] if len(item[1]) >= 1 else -1.0)
+    ids = tree_pandas.id.values.tolist()
+    if is_classifier:
+        value = numpy.array(tree_pandas.impurityStats.values.tolist())
+    else:
+        value = tree_pandas.prediction.values.tolist()
+
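+    # Each entry of the split column is either a dict with "featureIndex" and
+    # "leftCategoriesOrThreshold" keys or a struct-like value indexed by position.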
+    for item in tree_pandas.split:
+        if isinstance(item, dict):
+            try:
+                feature.append(item["featureIndex"])
+                threshold.append(item["leftCategoriesOrThreshold"])
+            except KeyError as e:
+                raise RuntimeError(f"Unable to process {item}.") from e
+        else:
+            tuple_item = tuple(item)
+            feature.append(tuple_item[0])
+            threshold.append(tuple_item[1][0] if len(tuple_item[1]) >= 1 else -1.0)
+
     tree = SparkMLTree()
+    tree.nodes_ids = ids
     tree.children_left = children_left
     tree.children_right = children_right
     tree.value = numpy.asarray(value, dtype=numpy.float32)
@@ -44,3 +58,105 @@ def save_read_sparkml_model_data(spark: SparkSession, model):
     model.write().overwrite().save(path)
     df = spark.read.parquet(os.path.join(path, 'data'))
     return df
+
+
+def get_default_tree_classifier_attribute_pairs():
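+    # Keys mirror the attributes of the ONNX-ML TreeEnsembleClassifier operator.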
+    attrs = {}
+    attrs['post_transform'] = 'NONE'
+    attrs['nodes_treeids'] = []
+    attrs['nodes_nodeids'] = []
+    attrs['nodes_featureids'] = []
+    attrs['nodes_modes'] = []
+    attrs['nodes_values'] = []
+    attrs['nodes_truenodeids'] = []
+    attrs['nodes_falsenodeids'] = []
+    attrs['nodes_missing_value_tracks_true'] = []
+    attrs['nodes_hitrates'] = []
+    attrs['class_treeids'] = []
+    attrs['class_nodeids'] = []
+    attrs['class_ids'] = []
+    attrs['class_weights'] = []
+    return attrs
+
+
+def get_default_tree_regressor_attribute_pairs():
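+    # Keys mirror the attributes of the ONNX-ML TreeEnsembleRegressor operator.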
+    attrs = {}
+    attrs['post_transform'] = 'NONE'
+    attrs['n_targets'] = 0
+    attrs['nodes_treeids'] = []
+    attrs['nodes_nodeids'] = []
+    attrs['nodes_featureids'] = []
+    attrs['nodes_modes'] = []
+    attrs['nodes_values'] = []
+    attrs['nodes_truenodeids'] = []
+    attrs['nodes_falsenodeids'] = []
+    attrs['nodes_missing_value_tracks_true'] = []
+    attrs['nodes_hitrates'] = []
+    attrs['target_treeids'] = []
+    attrs['target_nodeids'] = []
+    attrs['target_ids'] = []
+    attrs['target_weights'] = []
+    return attrs
+
+
+def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feature_id, mode, value, true_child_id,
+             false_child_id, weights, weight_id_bias, leaf_weights_are_counts):
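+    # Every node, branch or leaf, contributes one entry to each of the nodes_* lists.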
+    attr_pairs['nodes_treeids'].append(tree_id)
+    attr_pairs['nodes_nodeids'].append(node_id)
+    attr_pairs['nodes_featureids'].append(feature_id)
+    attr_pairs['nodes_modes'].append(mode)
+    attr_pairs['nodes_values'].append(value)
+    attr_pairs['nodes_truenodeids'].append(true_child_id)
+    attr_pairs['nodes_falsenodeids'].append(false_child_id)
+    attr_pairs['nodes_missing_value_tracks_true'].append(False)
+    attr_pairs['nodes_hitrates'].append(1.)
+
+    # Add leaf information for making predictions
+    if mode == 'LEAF':
+        flattened_weights = weights.flatten()
+        factor = tree_weight
+        # If the values stored at the leaves are counts of the possible classes,
+        # normalize them so that they become probabilities.
+        if leaf_weights_are_counts:
+            s = sum(flattened_weights)
+            factor /= float(s) if s != 0. else 1.
+        flattened_weights = [w * factor for w in flattened_weights]
+        if len(flattened_weights) == 2 and is_classifier:
+            flattened_weights = [flattened_weights[1]]
+
+        # Note that the attribute names used for predictions differ between classifiers and regressors
+        if is_classifier:
+            for i, w in enumerate(flattened_weights):
+                attr_pairs['class_treeids'].append(tree_id)
+                attr_pairs['class_nodeids'].append(node_id)
+                attr_pairs['class_ids'].append(i + weight_id_bias)
+                attr_pairs['class_weights'].append(w)
+        else:
+            for i, w in enumerate(flattened_weights):
+                attr_pairs['target_treeids'].append(tree_id)
+                attr_pairs['target_nodeids'].append(node_id)
+                attr_pairs['target_ids'].append(i + weight_id_bias)
+                attr_pairs['target_weights'].append(w)
+
+
+def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id, tree_weight,
+                                weight_id_bias, leaf_weights_are_counts):
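+    # Walk the sklearn-style tree node by node and emit a branch or leaf entry for each.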
+    for i in range(tree.node_count):
+        node_id = tree.nodes_ids[i]
+        weight = tree.value[i]
+
+        if tree.children_left[i] >= 0 or tree.children_right[i] >= 0:
+            mode = 'BRANCH_LEQ'
+            feat_id = tree.feature[i]
+            threshold = tree.threshold[i]
+            left_child_id = int(tree.children_left[i])
+            right_child_id = int(tree.children_right[i])
+        else:
+            mode = 'LEAF'
+            feat_id = 0
+            threshold = 0.
+            left_child_id = 0
+            right_child_id = 0
+
+        add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feat_id, mode, threshold,
+                 left_child_id, right_child_id, weight, weight_id_bias, leaf_weights_are_counts)
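
A minimal sketch (not part of the diff) of how these helpers fit together: it builds a toy SparkMLTree by hand and feeds it to add_tree_to_attribute_pairs. It assumes the caller also sets tree.feature, tree.threshold and tree.node_count, which the converter does outside the hunk shown above; the values below are invented purely for illustration.

    import numpy

    tree = SparkMLTree()
    tree.nodes_ids = [0, 1, 2]            # node ids, sorted by the "id" column
    tree.children_left = [1, -1, -1]      # -1 marks a leaf child
    tree.children_right = [2, -1, -1]
    tree.feature = [0, 0, 0]
    tree.threshold = [0.5, 0., 0.]
    tree.value = numpy.asarray([[0., 0.], [3., 1.], [1., 5.]], dtype=numpy.float32)
    tree.node_count = 3                   # assumed to be set by the caller

    attrs = get_default_tree_classifier_attribute_pairs()
    add_tree_to_attribute_pairs(attrs, True, tree, tree_id=0, tree_weight=1.,
                                weight_id_bias=0, leaf_weights_are_counts=True)
    # attrs now holds the flat nodes_*/class_* lists consumed by an ONNX-ML tree ensemble.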