Skip to content

Commit 0298ca4

Browse files
author
niklasmueboe
committed
more robust kde_nd
1 parent 0661d0a commit 0298ca4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ovrlpy/_kde.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def kde_nd(
8282
"""
8383
assert len(coordinates) >= 1
8484

85+
mins = tuple(int(floor(c.min())) for c in coordinates)
86+
maxs = tuple(int(floor(c.max() + 1)) for c in coordinates)
87+
assert all(min >= 0 for min in mins)
88+
8589
n = coordinates[0].shape[0]
8690
if not all(x.shape[0] == n for x in coordinates[1:]):
8791
raise ValueError("All coordinates must have the same number of rows")
@@ -93,11 +97,9 @@ def kde_nd(
9397
return np.zeros(size, dtype=dtype)
9498

9599
if size is None:
96-
size = tuple(int(floor(c.max() + 1)) for c in coordinates)
100+
size = maxs
97101

98-
dim_bins = [
99-
np.arange(int(c.min()), int(floor(c.max() + 1)) + 1) for c in coordinates
100-
]
102+
dim_bins = [np.arange(min, max + 1) for min, max in zip(mins, maxs)]
101103
counts, bins = np.histogramdd(coordinates, bins=dim_bins)
102104
kde = _kde(counts, bandwidth, dtype=dtype, **kwargs)
103105

0 commit comments

Comments
 (0)