Skip to content

Commit f5d9261

Browse files
Fix DataFrameGroupBy.agg error with NaN results
1 parent dcb5494 commit f5d9261

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

pandas/core/_numba/kernels/min_max_.py

Lines changed: 20 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,6 @@ def sliding_min_max(
3131
nobs = 0
3232
output = np.empty(N, dtype=result_dtype)
3333
na_pos = []
34-
# Use deque once numba supports it
35-
# https://github.com/numba/numba/issues/7417
3634
Q: list = []
3735
W: list = []
3836
for i in range(N):
@@ -46,36 +44,28 @@ def sliding_min_max(
4644
ai = values[k]
4745
if not np.isnan(ai):
4846
nobs += 1
49-
elif is_max:
50-
ai = -np.inf
5147
else:
52-
ai = np.inf
53-
# Discard previous entries if we find new min or max
48+
ai = -np.inf if is_max else np.inf
5449
if is_max:
55-
while Q and ((ai >= values[Q[-1]]) or values[Q[-1]] != values[Q[-1]]):
50+
while Q and ((ai >= values[Q[-1]]) or np.isnan(values[Q[-1]])):
5651
Q.pop()
5752
else:
58-
while Q and ((ai <= values[Q[-1]]) or values[Q[-1]] != values[Q[-1]]):
53+
while Q and ((ai <= values[Q[-1]]) or np.isnan(values[Q[-1]])):
5954
Q.pop()
6055
Q.append(k)
6156
W.append(k)
6257

63-
# Discard entries outside and left of current window
6458
while Q and Q[0] <= start[i] - 1:
6559
Q.pop(0)
6660
while W and W[0] <= start[i] - 1:
6761
if not np.isnan(values[W[0]]):
6862
nobs -= 1
6963
W.pop(0)
7064

71-
# Save output based on index in input value array
7265
if Q and curr_win_size > 0 and nobs >= min_periods:
7366
output[i] = values[Q[0]]
7467
else:
75-
if values.dtype.kind != "i":
76-
output[i] = np.nan
77-
else:
78-
na_pos.append(i)
68+
output[i] = np.nan
7969

8070
return output, na_pos
8171

@@ -100,27 +90,31 @@ def grouped_min_max(
10090
if lab < 0:
10191
continue
10292

103-
if values.dtype.kind == "i" or not np.isnan(val):
93+
if not np.isnan(val):
10494
nobs[lab] += 1
10595
else:
106-
# NaN value cannot be a min/max value
10796
continue
10897

10998
if nobs[lab] == 1:
110-
# First element in group, set output equal to this
11199
output[lab] = val
112-
continue
113-
114-
if is_max:
115-
if val > output[lab]:
116-
output[lab] = val
117100
else:
118-
if val < output[lab]:
119-
output[lab] = val
101+
if is_max:
102+
if val > output[lab]:
103+
output[lab] = val
104+
else:
105+
if val < output[lab]:
106+
output[lab] = val
120107

121-
# Set labels that don't satisfy min_periods as np.nan
122108
for lab, count in enumerate(nobs):
123109
if count < min_periods:
124-
na_pos.append(lab)
110+
output[lab] = np.nan
125111

126112
return output, na_pos
113+
114+
# Example usage:
115+
if __name__ == "__main__":
116+
import pandas as pd
117+
118+
s = pd.Series([np.nan, 0, 1], dtype="Float64")
119+
print((s / s).max()) # <NA>
120+
print((s / s).groupby([9, 9, 9]).max().iat[0]) # <NA>

0 commit comments

Comments
 (0)