Skip to content

Commit 7a0b90e

Browse files
committed
Add support to HistGradientBoost (#2)
1 parent 6d8bf2f commit 7a0b90e

File tree

7 files changed

+751
-17006
lines changed

7 files changed

+751
-17006
lines changed

examples/CustomTreeExample.ipynb

Lines changed: 80 additions & 64 deletions
Large diffs are not rendered by default.

examples/GradientBoostExample.ipynb

Lines changed: 147 additions & 117 deletions
Large diffs are not rendered by default.

examples/LightGBMAndXGBoostExample.ipynb

Lines changed: 390 additions & 180 deletions
Large diffs are not rendered by default.

examples/MainExample.ipynb

Lines changed: 31 additions & 16632 deletions
Large diffs are not rendered by default.

supertree/js/supertree.min.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

supertree/node.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
2+
import pandas as pd
33

44
class Node:
55
def __init__(self, feature, threshold, impurity, samples,
@@ -22,16 +22,38 @@ def __init__(self, feature, threshold, impurity, samples,
2222

2323
def to_dict(self):
2424
def convert(value):
    """Recursively turn NumPy/pandas values into plain Python objects.

    Args:
        value: Any object. NumPy integer and floating scalars, NumPy
            arrays, pandas Series, lists and dicts are converted; every
            other object is returned unchanged.

    Returns:
        ``int`` for NumPy integer scalars, ``float`` for NumPy floating
        scalars, a (possibly nested) ``list`` for arrays, Series and
        lists, a ``dict`` with converted values for dicts, otherwise
        ``value`` itself.
    """
    # The abstract scalar types cover every signed/unsigned width and every
    # float width, including platform-dependent ones. Naming ``np.float128``
    # explicitly raises AttributeError on builds that do not provide it.
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, (np.ndarray, pd.Series)):
        # ``tolist()`` may yield nested lists; recurse to convert each item.
        return [convert(item) for item in value.tolist()]
    if isinstance(value, list):
        return [convert(item) for item in value]
    if isinstance(value, dict):
        return {key: convert(val) for key, val in value.items()}
    return value
3052

3153
node_dict = {
3254
"feature": int(self.feature) if isinstance(self.feature, np.longlong) else self.feature,
3355
"threshold": self.threshold,
34-
"impurity": self.impurity,
56+
"impurity": convert(self.impurity),
3557
"samples": int(self.samples) if isinstance(self.samples, np.longlong) else self.samples,
3658
"class_distribution": [convert(val) for val in self.class_distribution],
3759
"treeclass": convert(self.treeclass),

supertree/supertree.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def __init__(
4747
"XGBRegressor",
4848
"XGBRFClassifier",
4949
"XGBRFRegressor",
50-
"ModelLoader"
50+
"ModelLoader",
51+
"HistGradientBoostingClassifier",
52+
"HistGradientBoostingRegressor",
5153
]
5254

5355
if model.__class__.__name__ not in valid_model_classes:
@@ -174,17 +176,21 @@ def __init__(
174176

175177
self.licence_key = licence_key
176178

177-
def show_tree(self, which_tree=0, start_depth=5):
179+
def show_tree(self, which_tree=0, which_iteration=0, start_depth=5):
178180
"""
179181
Displaying model HTMl template and create json tree model.
180182
"""
181183
if not isinstance(which_tree, int):
182184
raise TypeError("Invalid which_tree type. Expected an integer.")
183185

186+
if not isinstance(which_iteration, int):
187+
raise TypeError("Invalid which_iteration type. Expected an integer.")
188+
184189
if not isinstance(start_depth, int):
185190
raise TypeError("Invalid start_depth type. Expected an integer.")
186191

187192
self.which_tree = which_tree
193+
self.which_iteration = which_iteration
188194

189195
if self.model_type == "uknown_model":
190196
return 0
@@ -204,7 +210,7 @@ def show_tree(self, which_tree=0, start_depth=5):
204210
display(HTML(templatehtml.get_d3_html(
205211
combined_data_str,start_depth, self.licence_key)))
206212

207-
def save_html(self, filename="output", which_tree=0,start_depth=5):
213+
def save_html(self, filename="output", which_tree=0, which_iteration=0,start_depth=5):
208214
"""
209215
Saving HTML file and create json tree model.
210216
"""
@@ -218,11 +224,14 @@ def save_html(self, filename="output", which_tree=0,start_depth=5):
218224
if not isinstance(start_depth, int):
219225
raise TypeError("Invalid start_depth type. Expected an integer.")
220226

227+
if not isinstance(which_iteration, int):
228+
raise TypeError("Invalid which_iteration type. Expected an integer.")
229+
221230
if filename is not None and not isinstance(filename, str):
222231
raise TypeError("Invalid filename type. Expected a string.")
223232

224233
self.which_tree = which_tree
225-
234+
self.which_iteration = which_iteration
226235
d3script = """
227236
<script src="https://cdn.jsdelivr.net/npm/d3@7" charset="utf-8"></script>
228237
<script src="https://cdn.jsdelivr.net/npm/tweetnacl@1.0.3/nacl.min.js"></script>
@@ -312,7 +321,7 @@ def create_node_dfs(self, node_index, left_right, threshold, feature, x_axis):
312321
node.start_end_x_axis,
313322
)
314323

315-
def save_json_tree(self, filename="treedata", which_tree=0):
324+
def save_json_tree(self, filename="treedata", which_tree=0, which_iteration = 0):
316325
"""
317326
Save tree to json tree.
318327
"""
@@ -325,7 +334,11 @@ def save_json_tree(self, filename="treedata", which_tree=0):
325334
if filename is not None and not isinstance(filename, str):
326335
raise TypeError("Invalid filename type. Expected a string.")
327336

337+
if not isinstance(which_iteration, int):
338+
raise TypeError("Invalid which_iteration type. Expected an integer.")
339+
328340
self.which_tree = which_tree
341+
self.which_iteration = which_iteration
329342

330343
combined_data_str = self.get_combined_data()
331344

@@ -344,6 +357,7 @@ def which_model(self):
344357
"LGBMClassifier",
345358
"XGBClassifier",
346359
"XGBRFClassifier",
360+
"HistGradientBoostingClassifier",
347361
):
348362
return "classification"
349363
elif self.model_name in (
@@ -355,6 +369,7 @@ def which_model(self):
355369
"LGBMRegressor",
356370
"XGBRegressor",
357371
"XGBRFRegressor",
372+
"HistGradientBoostingRegressor",
358373
):
359374
return "regression"
360375

@@ -421,10 +436,10 @@ def convert_model_to_dict_array(self):
421436
"GradientBoostingRegressor",
422437
"GradientBoostingClassifier",
423438
):
424-
if (0 <= self.which_tree < len(self.model.estimators_) and 0 <= self.which_iteration < len(self.model.estimators_[self.which_tree]) ):
425-
super_tree = self.model.estimators_[self.which_tree, self.which_iteration].tree_
439+
if (0 <= self.which_iteration < len(self.model.estimators_) and 0 <= self.which_tree < len(self.model.estimators_[self.which_tree]) ):
440+
super_tree = self.model.estimators_[self.which_iteration, self.which_tree].tree_
426441
else:
427-
raise IndexError("Wartość 'which_tree' lub 'which_iteration' jest poza zakresem dostępnych wartości.")
442+
raise IndexError("Value of 'which_tree' or 'which_iteration' is out of range.")
428443
else:
429444
super_tree = self.model.tree_
430445

@@ -489,7 +504,6 @@ def convert_model_to_dict_array(self):
489504
else:
490505
model_dict = self.model.get_dump(with_stats=True, dump_format="json")
491506

492-
493507
json_tree = model_dict[self.which_tree]
494508

495509
if 0 <= self.which_tree < len(model_dict):
@@ -501,6 +515,16 @@ def convert_model_to_dict_array(self):
501515
if model_name in ("ModelLoader"):
502516
self.node_list = self.model.model_dict
503517

518+
if model_name in ("HistGradientBoostingClassifier", "HistGradientBoostingRegressor"):
519+
520+
if not (0 <= self.which_iteration < len(self.model._predictors)):
521+
raise IndexError(f"which_iteration {self.which_iteration} is out of range. Valid range is 0 to {len(self.model._predictors) - 1}.")
522+
523+
if not (0 <= self.which_tree < len(self.model._predictors[self.which_iteration])):
524+
raise IndexError(f"which_tree {self.which_tree} is out of range for iteration {self.which_iteration}. Valid range is 0 to {len(self.model._predictors[self.which_iteration]) - 1}.")
525+
nodes = self.model._predictors[self.which_iteration][self.which_tree].nodes
526+
self.collect_node_info_histgb(nodes)
527+
504528

505529
def collect_node_info_lgbm(self, node, depth=0):
506530
node_index = len(self.node_list)
@@ -673,3 +697,47 @@ def collect_node_info_xgboost(self, node, depth=0):
673697
self.node_list.append(node_info)
674698

675699
return node_index
700+
701+
def collect_node_info_histgb(self, nodes):
    """Append one node-info dict per entry of ``nodes`` to ``self.node_list``.

    Each node is read positionally: ``[1]`` samples, ``[2]`` feature index,
    ``[3]`` threshold, ``[5]``/``[6]`` left/right child index, ``[8]`` the
    value stored under ``"impurity"`` and ``[9]`` the value used to pick the
    predicted class.

    Args:
        nodes: Iterable of indexable node records.

    Reads ``self.model_type``, ``self.model_name`` and ``self.target_names``;
    mutates ``self.node_list`` in place and returns ``None``.
    """
    # NOTE(review): in scikit-learn's predictor node record, positions 8 and 9
    # appear to be ``depth`` and ``is_leaf`` rather than an impurity and a
    # class index - confirm that reading them this way is intended.
    for i, node in enumerate(nodes):
        feature_index = int(node[2])
        threshold = node[3]
        left_child = int(node[5])
        right_child = int(node[6])
        samples = int(node[1])
        impurity = node[8]

        # Compare with ``==``: ``x not in ("nodata")`` is a substring test
        # on a plain string, and two independent ``if`` statements could
        # leave the names below unbound.
        if self.model_type == "nodata":
            class_dist = "No Data"
            predicted_class = node[9]
        elif self.model_name == "HistGradientBoostingRegressor":
            # NOTE(review): placeholder distribution kept from the original.
            class_dist = [[10, 10, 10]]
            predicted_class = self.target_names[0]
        else:
            class_dist = None
            predicted_class = self.target_names[node[9]]

        # A child index of 0 is rewritten to -1; presumably 0 marks
        # "no child" because index 0 is the root - TODO confirm.
        if left_child == 0:
            left_child = -1
        if right_child == 0:
            right_child = -1

        is_leaf = left_child == -1 and right_child == -1

        self.node_list.append(
            {
                "index": i,
                "feature": feature_index,
                "impurity": impurity,
                "threshold": threshold,
                "class_distribution": class_dist,
                "predicted_class": predicted_class,
                "samples": samples,
                "is_leaf": is_leaf,
                "left_child_index": left_child,
                "right_child_index": right_child,
            }
        )

0 commit comments

Comments
 (0)