Skip to content

Commit 479fae6

Browse files
authored
fixing categorical states bug #36
1 parent 337aec8 commit 479fae6

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/simdec/decomposition.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,32 @@ def statistic_(inputs):
167167
sorted_inputs = np.sort(inputs, axis=0)
168168
bin_edges = []
169169
for i, states_ in enumerate(states):
170+
col = inputs[:, i]
171+
uniq = np.unique(col)
172+
173+
# If this input has only a few unique numeric values (categorical-like),
174+
# build bin edges around unique values so we don't get empty states.
175+
if uniq.size <= 5:
176+
uniq = np.sort(uniq).astype(float)
177+
if uniq.size == 1:
178+
bin_edges_ = np.array([uniq[0] - 0.5, uniq[0] + 0.5], dtype=float)
179+
else:
180+
gaps = np.diff(uniq)
181+
margin = 0.1 * np.min(gaps) # fixing boundaries for categorical
182+
183+
# edges length = n_unique + 1
184+
bin_edges_ = np.concatenate(
185+
([uniq[0] - margin], uniq[:-1] + margin, [uniq[-1] + margin])
186+
).astype(float)
187+
188+
bin_edges.append(bin_edges_)
189+
continue
190+
170191
splits = np.array_split(sorted_inputs[:, i], states_)
192+
171193
bin_edges_ = [splits_[0] for splits_ in splits]
172194
bin_edges_.append(splits[-1][-1]) # last point to close the edges
173-
# bin_edges_ = np.unique(bin_edges_) # remove duplicate points, sorted
195+
bin_edges_ = np.array(bin_edges_, dtype=float)
174196
bin_edges_ += 1e-10 * np.linspace(0, 1, len(bin_edges_))
175197
bin_edges.append(bin_edges_)
176198

0 commit comments

Comments
 (0)