@@ -167,10 +167,32 @@ def statistic_(inputs):
167167 sorted_inputs = np .sort (inputs , axis = 0 )
168168 bin_edges = []
169169 for i , states_ in enumerate (states ):
170+ col = inputs [:, i ]
171+ uniq = np .unique (col )
172+
173+ # If this input has only a few unique numeric values (categorical-like),
174+ # build bin edges around unique values so we don't get empty states.
175+ if uniq .size <= 5 :
176+ uniq = np .sort (uniq ).astype (float )
177+ if uniq .size == 1 :
178+ bin_edges_ = np .array ([uniq [0 ] - 0.5 , uniq [0 ] + 0.5 ], dtype = float )
179+ else :
180+ gaps = np .diff (uniq )
181+ margin = 0.1 * np .min (gaps ) # fixing boundaries for categorical
182+
183+ # edges length = n_unique + 1
184+ bin_edges_ = np .concatenate (
185+ ([uniq [0 ] - margin ], uniq [:- 1 ] + margin , [uniq [- 1 ] + margin ])
186+ ).astype (float )
187+
188+ bin_edges .append (bin_edges_ )
189+ continue
190+
170191 splits = np .array_split (sorted_inputs [:, i ], states_ )
192+
171193 bin_edges_ = [splits_ [0 ] for splits_ in splits ]
172194 bin_edges_ .append (splits [- 1 ][- 1 ]) # last point to close the edges
173- # bin_edges_ = np.unique (bin_edges_) # remove duplicate points, sorted
195+ bin_edges_ = np .array (bin_edges_ , dtype = float )
174196 bin_edges_ += 1e-10 * np .linspace (0 , 1 , len (bin_edges_ ))
175197 bin_edges .append (bin_edges_ )
176198
0 commit comments