@@ -47,7 +47,9 @@ def __init__(
4747 "XGBRegressor" ,
4848 "XGBRFClassifier" ,
4949 "XGBRFRegressor" ,
50- "ModelLoader"
50+ "ModelLoader" ,
51+ "HistGradientBoostingClassifier" ,
52+ "HistGradientBoostingRegressor" ,
5153 ]
5254
5355 if model .__class__ .__name__ not in valid_model_classes :
@@ -174,17 +176,21 @@ def __init__(
174176
175177 self .licence_key = licence_key
176178
177- def show_tree (self , which_tree = 0 , start_depth = 5 ):
179+ def show_tree (self , which_tree = 0 , which_iteration = 0 , start_depth = 5 ):
178180 """
179181 Displaying model HTMl template and create json tree model.
180182 """
181183 if not isinstance (which_tree , int ):
182184 raise TypeError ("Invalid which_tree type. Expected an integer." )
183185
186+ if not isinstance (which_iteration , int ):
187+ raise TypeError ("Invalid which_iteration type. Expected an integer." )
188+
184189 if not isinstance (start_depth , int ):
185190 raise TypeError ("Invalid start_depth type. Expected an integer." )
186191
187192 self .which_tree = which_tree
193+ self .which_iteration = which_iteration
188194
189195 if self .model_type == "uknown_model" :
190196 return 0
@@ -204,7 +210,7 @@ def show_tree(self, which_tree=0, start_depth=5):
204210 display (HTML (templatehtml .get_d3_html (
205211 combined_data_str ,start_depth , self .licence_key )))
206212
207- def save_html (self , filename = "output" , which_tree = 0 ,start_depth = 5 ):
213+ def save_html (self , filename = "output" , which_tree = 0 , which_iteration = 0 , start_depth = 5 ):
208214 """
209215 Saving HTML file and create json tree model.
210216 """
@@ -218,11 +224,14 @@ def save_html(self, filename="output", which_tree=0,start_depth=5):
218224 if not isinstance (start_depth , int ):
219225 raise TypeError ("Invalid start_depth type. Expected an integer." )
220226
227+ if not isinstance (which_iteration , int ):
228+ raise TypeError ("Invalid which_iteration type. Expected an integer." )
229+
221230 if filename is not None and not isinstance (filename , str ):
222231 raise TypeError ("Invalid filename type. Expected a string." )
223232
224233 self .which_tree = which_tree
225-
234+ self . which_iteration = which_iteration
226235 d3script = """
227236 <script src="https://cdn.jsdelivr.net/npm/d3@7" charset="utf-8"></script>
228237 <script src="https://cdn.jsdelivr.net/npm/tweetnacl@1.0.3/nacl.min.js"></script>
@@ -312,7 +321,7 @@ def create_node_dfs(self, node_index, left_right, threshold, feature, x_axis):
312321 node .start_end_x_axis ,
313322 )
314323
315- def save_json_tree (self , filename = "treedata" , which_tree = 0 ):
324+ def save_json_tree (self , filename = "treedata" , which_tree = 0 , which_iteration = 0 ):
316325 """
317326 Save tree to json tree.
318327 """
@@ -325,7 +334,11 @@ def save_json_tree(self, filename="treedata", which_tree=0):
325334 if filename is not None and not isinstance (filename , str ):
326335 raise TypeError ("Invalid filename type. Expected a string." )
327336
337+ if not isinstance (which_iteration , int ):
338+ raise TypeError ("Invalid which_iteration type. Expected an integer." )
339+
328340 self .which_tree = which_tree
341+ self .which_iteration = which_iteration
329342
330343 combined_data_str = self .get_combined_data ()
331344
@@ -344,6 +357,7 @@ def which_model(self):
344357 "LGBMClassifier" ,
345358 "XGBClassifier" ,
346359 "XGBRFClassifier" ,
360+ "HistGradientBoostingClassifier" ,
347361 ):
348362 return "classification"
349363 elif self .model_name in (
@@ -355,6 +369,7 @@ def which_model(self):
355369 "LGBMRegressor" ,
356370 "XGBRegressor" ,
357371 "XGBRFRegressor" ,
372+ "HistGradientBoostingRegressor" ,
358373 ):
359374 return "regression"
360375
@@ -421,10 +436,10 @@ def convert_model_to_dict_array(self):
421436 "GradientBoostingRegressor" ,
422437 "GradientBoostingClassifier" ,
423438 ):
424- if (0 <= self .which_tree < len (self .model .estimators_ ) and 0 <= self .which_iteration < len (self .model .estimators_ [self .which_tree ]) ):
425- super_tree = self .model .estimators_ [self .which_tree , self .which_iteration ].tree_
439+ if (0 <= self .which_iteration < len (self .model .estimators_ ) and 0 <= self .which_tree < len (self .model .estimators_ [self .which_tree ]) ):
440+ super_tree = self .model .estimators_ [self .which_iteration , self .which_tree ].tree_
426441 else :
427- raise IndexError ("Wartość 'which_tree' lub 'which_iteration' jest poza zakresem dostępnych wartości ." )
442+ raise IndexError ("Value of 'which_tree' or 'which_iteration' is out of range ." )
428443 else :
429444 super_tree = self .model .tree_
430445
@@ -489,7 +504,6 @@ def convert_model_to_dict_array(self):
489504 else :
490505 model_dict = self .model .get_dump (with_stats = True , dump_format = "json" )
491506
492-
493507 json_tree = model_dict [self .which_tree ]
494508
495509 if 0 <= self .which_tree < len (model_dict ):
@@ -501,6 +515,16 @@ def convert_model_to_dict_array(self):
501515 if model_name in ("ModelLoader" ):
502516 self .node_list = self .model .model_dict
503517
518+ if model_name in ("HistGradientBoostingClassifier" , "HistGradientBoostingRegressor" ):
519+
520+ if not (0 <= self .which_iteration < len (self .model ._predictors )):
521+ raise IndexError (f"which_iteration { self .which_iteration } is out of range. Valid range is 0 to { len (self .model ._predictors ) - 1 } ." )
522+
523+ if not (0 <= self .which_tree < len (self .model ._predictors [self .which_iteration ])):
524+ raise IndexError (f"which_tree { self .which_tree } is out of range for iteration { self .which_iteration } . Valid range is 0 to { len (self .model ._predictors [self .which_iteration ]) - 1 } ." )
525+ nodes = self .model ._predictors [self .which_iteration ][self .which_tree ].nodes
526+ self .collect_node_info_histgb (nodes )
527+
504528
505529 def collect_node_info_lgbm (self , node , depth = 0 ):
506530 node_index = len (self .node_list )
@@ -673,3 +697,47 @@ def collect_node_info_xgboost(self, node, depth=0):
673697 self .node_list .append (node_info )
674698
675699 return node_index
700+
701+ def collect_node_info_histgb (self , nodes ):
702+ for i , node in enumerate (nodes ):
703+ feature_index = int (node [2 ])
704+ threshold = node [3 ]
705+ left_child = int (node [5 ])
706+ right_child = int (node [6 ])
707+ samples = int (node [1 ])
708+ impurity = node [8 ]
709+ if ( self .model_type not in ("nodata" )):
710+ if (self .model_name == "HistGradientBoostingRegressor" ):
711+ class_dist = [[10 ,10 ,10 ]]
712+ predicted_class = self .target_names [0 ]
713+ else :
714+ class_dist = None
715+ predicted_class_index = node [9 ]
716+ predicted_class = self .target_names [predicted_class_index ]
717+ if (self .model_type == "nodata" ):
718+ class_dist = "No Data"
719+ predicted_class = node [9 ]
720+
721+ if left_child == 0 :
722+ left_child = - 1
723+ if right_child == 0 :
724+ right_child = - 1
725+
726+ if left_child == - 1 and right_child == - 1 :
727+ is_leaf = True
728+ else :
729+ is_leaf = False
730+
731+ node_info = {
732+ "index" : i ,
733+ "feature" : feature_index ,
734+ "impurity" : impurity ,
735+ "threshold" : threshold ,
736+ "class_distribution" : class_dist ,
737+ "predicted_class" : predicted_class ,
738+ "samples" : samples ,
739+ "is_leaf" : is_leaf ,
740+ "left_child_index" : left_child ,
741+ "right_child_index" : right_child ,
742+ }
743+ self .node_list .append (node_info )
0 commit comments