@@ -167,34 +167,45 @@ def statistic_(inputs):
167167 sorted_inputs = np .sort (inputs , axis = 0 )
168168 bin_edges = []
169169 for i , states_ in enumerate (states ):
170- col = inputs [:, i ]
171- uniq = np .unique (col )
170+ col = inputs [:, i ]
171+ uniq = np .unique (col )
172172
173173 # If this input has only a few unique numeric values (categorical-like),
174174 # build bin edges around unique values so we don't get empty states.
175- if uniq .size <= 5 :
176- uniq = np .sort (uniq ).astype (float )
177- if uniq .size == 1 :
178- bin_edges_ = np .array ([uniq [0 ] - 0.5 , uniq [0 ] + 0.5 ], dtype = float )
179- else :
180- gaps = np .diff (uniq )
181- margin = 0.1 * np .min (gaps ) # fixing boundaries for categorical
182-
183- # edges length = n_unique + 1
184- bin_edges_ = np .concatenate (
185- ([uniq [0 ] - margin ], uniq [:- 1 ] + margin , [uniq [- 1 ] + margin ])
186- ).astype (float )
187-
188- bin_edges .append (bin_edges_ )
189- continue
190-
191- splits = np .array_split (sorted_inputs [:, i ], states_ )
192-
193- bin_edges_ = [splits_ [0 ] for splits_ in splits ]
194- bin_edges_ .append (splits [- 1 ][- 1 ]) # last point to close the edges
195- bin_edges_ = np .array (bin_edges_ , dtype = float )
196- bin_edges_ += 1e-10 * np .linspace (0 , 1 , len (bin_edges_ ))
197- bin_edges .append (bin_edges_ )
175+ sorted_inputs = np .sort (inputs , axis = 0 )
176+ bin_edges = []
177+
178+ for i , states_ in enumerate (states ):
179+ col = inputs [:, i ]
180+ uniq = np .unique (col )
181+
182+ # Categorical-like numeric inputs
183+ if uniq .size <= 5 and states_ == uniq .size :
184+ uniq = np .sort (uniq ).astype (float )
185+
186+ if uniq .size == 1 :
187+ bin_edges_ = np .array (
188+ [uniq [0 ] - 0.5 , uniq [0 ] + 0.5 ], dtype = float
189+ )
190+ else :
191+ gaps = np .diff (uniq )
192+ margin = 0.1 * np .min (gaps )
193+
194+ bin_edges_ = np .concatenate (
195+ ([uniq [0 ] - margin ], uniq [:- 1 ] + margin , [uniq [- 1 ] + margin ])
196+ ).astype (float )
197+
198+ bin_edges .append (bin_edges_ )
199+ continue
200+
201+ # Default: equal-number-of-samples bins
202+ splits = np .array_split (sorted_inputs [:, i ], states_ )
203+ bin_edges_ = [splits_ [0 ] for splits_ in splits ]
204+ bin_edges_ .append (splits [- 1 ][- 1 ]) # last point to close the edges
205+ bin_edges_ = np .array (bin_edges_ , dtype = float )
206+ bin_edges_ += 1e-10 * np .linspace (0 , 1 , len (bin_edges_ ))
207+ bin_edges .append (bin_edges_ )
208+
198209
199210 res = stats .binned_statistic_dd (
200211 inputs , values = output , statistic = statistic_ , bins = bin_edges
0 commit comments