Skip to content

Commit dbed131

Browse files
authored
REF: de-duplicate min_count checking in groupby.pyx (#51066)
* REF: de-duplicate groupby.pyx nan-filling * share more * share more * remove comment * share more * inline * catch complex cases
1 parent 37c9523 commit dbed131

File tree

1 file changed

+81
-133
lines changed

1 file changed

+81
-133
lines changed

pandas/_libs/groupby.pyx

Lines changed: 81 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -747,31 +747,9 @@ def group_sum(
747747
compensation[lab, j] = t - sumx[lab, j] - y
748748
sumx[lab, j] = t
749749

750-
for i in range(ncounts):
751-
for j in range(K):
752-
if nobs[i, j] < min_count:
753-
# if we are integer dtype, not is_datetimelike, and
754-
# not uses_mask, then getting here implies that
755-
# counts[i] < min_count, which means we will
756-
# be cast to float64 and masked at the end
757-
# of WrappedCythonOp._call_cython_op. So we can safely
758-
# set a placeholder value in out[i, j].
759-
if uses_mask:
760-
result_mask[i, j] = True
761-
elif (
762-
sum_t is float32_t
763-
or sum_t is float64_t
764-
or sum_t is complex64_t
765-
):
766-
out[i, j] = NAN
767-
elif sum_t is int64_t:
768-
out[i, j] = NPY_NAT
769-
else:
770-
# placeholder, see above
771-
out[i, j] = 0
772-
773-
else:
774-
out[i, j] = sumx[i, j]
750+
_check_below_mincount(
751+
out, uses_mask, result_mask, ncounts, K, nobs, min_count, sumx
752+
)
775753

776754

777755
@cython.wraparound(False)
@@ -823,23 +801,9 @@ def group_prod(
823801
nobs[lab, j] += 1
824802
prodx[lab, j] *= val
825803

826-
for i in range(ncounts):
827-
for j in range(K):
828-
if nobs[i, j] < min_count:
829-
830-
# else case is not possible
831-
if uses_mask:
832-
result_mask[i, j] = True
833-
# Be deterministic, out was initialized as empty
834-
out[i, j] = 0
835-
elif int64float_t is float32_t or int64float_t is float64_t:
836-
out[i, j] = NAN
837-
else:
838-
# we only get here when < mincount which gets handled later
839-
pass
840-
841-
else:
842-
out[i, j] = prodx[i, j]
804+
_check_below_mincount(
805+
out, uses_mask, result_mask, ncounts, K, nobs, min_count, prodx
806+
)
843807

844808

845809
@cython.wraparound(False)
@@ -1271,6 +1235,65 @@ cdef numeric_t _get_na_val(numeric_t val, bint is_datetimelike):
12711235
return na_val
12721236

12731237

1238+
ctypedef fused mincount_t:
1239+
numeric_t
1240+
complex64_t
1241+
complex128_t
1242+
1243+
1244+
@cython.wraparound(False)
1245+
@cython.boundscheck(False)
1246+
cdef inline void _check_below_mincount(
1247+
mincount_t[:, ::1] out,
1248+
bint uses_mask,
1249+
uint8_t[:, ::1] result_mask,
1250+
Py_ssize_t ncounts,
1251+
Py_ssize_t K,
1252+
int64_t[:, ::1] nobs,
1253+
int64_t min_count,
1254+
mincount_t[:, ::1] resx,
1255+
) nogil:
1256+
"""
1257+
Check if the number of observations for a group is below min_count,
1258+
and if so set the result for that group to the appropriate NA-like value.
1259+
"""
1260+
cdef:
1261+
Py_ssize_t i, j
1262+
1263+
for i in range(ncounts):
1264+
for j in range(K):
1265+
1266+
if nobs[i, j] < min_count:
1267+
# if we are integer dtype, not is_datetimelike, and
1268+
# not uses_mask, then getting here implies that
1269+
# counts[i] < min_count, which means we will
1270+
# be cast to float64 and masked at the end
1271+
# of WrappedCythonOp._call_cython_op. So we can safely
1272+
# set a placeholder value in out[i, j].
1273+
if uses_mask:
1274+
result_mask[i, j] = True
1275+
# set out[i, j] to 0 to be deterministic, as
1276+
# it was initialized with np.empty. Also ensures
1277+
# we can downcast out if appropriate.
1278+
out[i, j] = 0
1279+
elif (
1280+
mincount_t is float32_t
1281+
or mincount_t is float64_t
1282+
or mincount_t is complex64_t
1283+
or mincount_t is complex128_t
1284+
):
1285+
out[i, j] = NAN
1286+
elif mincount_t is int64_t:
1287+
# Per above, this is a placeholder in
1288+
# non-is_datetimelike cases.
1289+
out[i, j] = NPY_NAT
1290+
else:
1291+
# placeholder, see above
1292+
out[i, j] = 0
1293+
else:
1294+
out[i, j] = resx[i, j]
1295+
1296+
12741297
# TODO(cython3): GH#31710 use memoryviews once cython 0.30 is released so we can
12751298
# use `const numeric_object_t[:, :] values`
12761299
@cython.wraparound(False)
@@ -1291,8 +1314,8 @@ def group_last(
12911314
cdef:
12921315
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
12931316
numeric_object_t val
1294-
ndarray[numeric_object_t, ndim=2] resx
1295-
ndarray[int64_t, ndim=2] nobs
1317+
numeric_object_t[:, ::1] resx
1318+
int64_t[:, ::1] nobs
12961319
bint uses_mask = mask is not None
12971320
bint isna_entry
12981321

@@ -1327,7 +1350,7 @@ def group_last(
13271350
isna_entry = checknull(val)
13281351

13291352
if not isna_entry:
1330-
# NB: use _treat_as_na here once
1353+
# TODO(cython3): use _treat_as_na here once
13311354
# conditional-nogil is available.
13321355
nobs[lab, j] += 1
13331356
resx[lab, j] = val
@@ -1358,33 +1381,9 @@ def group_last(
13581381
nobs[lab, j] += 1
13591382
resx[lab, j] = val
13601383

1361-
for i in range(ncounts):
1362-
for j in range(K):
1363-
# TODO(cython3): the entire next block can be shared
1364-
# across 3 places once conditional-nogil is available
1365-
if nobs[i, j] < min_count:
1366-
# if we are integer dtype, not is_datetimelike, and
1367-
# not uses_mask, then getting here implies that
1368-
# counts[i] < min_count, which means we will
1369-
# be cast to float64 and masked at the end
1370-
# of WrappedCythonOp._call_cython_op. So we can safely
1371-
# set a placeholder value in out[i, j].
1372-
if uses_mask:
1373-
result_mask[i, j] = True
1374-
elif (
1375-
numeric_object_t is float32_t
1376-
or numeric_object_t is float64_t
1377-
):
1378-
out[i, j] = NAN
1379-
elif numeric_object_t is int64_t:
1380-
# Per above, this is a placeholder in
1381-
# non-is_datetimelike cases.
1382-
out[i, j] = NPY_NAT
1383-
else:
1384-
# placeholder, see above
1385-
out[i, j] = 0
1386-
else:
1387-
out[i, j] = resx[i, j]
1384+
_check_below_mincount(
1385+
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
1386+
)
13881387

13891388

13901389
# TODO(cython3): GH#31710 use memoryviews once cython 0.30 is released so we can
@@ -1408,8 +1407,8 @@ def group_nth(
14081407
cdef:
14091408
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
14101409
numeric_object_t val
1411-
ndarray[numeric_object_t, ndim=2] resx
1412-
ndarray[int64_t, ndim=2] nobs
1410+
numeric_object_t[:, ::1] resx
1411+
int64_t[:, ::1] nobs
14131412
bint uses_mask = mask is not None
14141413
bint isna_entry
14151414

@@ -1444,7 +1443,7 @@ def group_nth(
14441443
isna_entry = checknull(val)
14451444

14461445
if not isna_entry:
1447-
# NB: use _treat_as_na here once
1446+
# TODO(cython3): use _treat_as_na here once
14481447
# conditional-nogil is available.
14491448
nobs[lab, j] += 1
14501449
if nobs[lab, j] == rank:
@@ -1478,37 +1477,9 @@ def group_nth(
14781477
if nobs[lab, j] == rank:
14791478
resx[lab, j] = val
14801479

1481-
# TODO: de-dup this whole block with group_last?
1482-
for i in range(ncounts):
1483-
for j in range(K):
1484-
if nobs[i, j] < min_count:
1485-
# if we are integer dtype, not is_datetimelike, and
1486-
# not uses_mask, then getting here implies that
1487-
# counts[i] < min_count, which means we will
1488-
# be cast to float64 and masked at the end
1489-
# of WrappedCythonOp._call_cython_op. So we can safely
1490-
# set a placeholder value in out[i, j].
1491-
if uses_mask:
1492-
result_mask[i, j] = True
1493-
# set out[i, j] to 0 to be deterministic, as
1494-
# it was initialized with np.empty. Also ensures
1495-
# we can downcast out if appropriate.
1496-
out[i, j] = 0
1497-
elif (
1498-
numeric_object_t is float32_t
1499-
or numeric_object_t is float64_t
1500-
):
1501-
out[i, j] = NAN
1502-
elif numeric_object_t is int64_t:
1503-
# Per above, this is a placeholder in
1504-
# non-is_datetimelike cases.
1505-
out[i, j] = NPY_NAT
1506-
else:
1507-
# placeholder, see above
1508-
out[i, j] = 0
1509-
1510-
else:
1511-
out[i, j] = resx[i, j]
1480+
_check_below_mincount(
1481+
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
1482+
)
15121483

15131484

15141485
@cython.boundscheck(False)
@@ -1643,7 +1614,7 @@ cdef group_min_max(
16431614
cdef:
16441615
Py_ssize_t i, j, N, K, lab, ngroups = len(counts)
16451616
numeric_t val
1646-
ndarray[numeric_t, ndim=2] group_min_or_max
1617+
numeric_t[:, ::1] group_min_or_max
16471618
int64_t[:, ::1] nobs
16481619
bint uses_mask = mask is not None
16491620
bint isna_entry
@@ -1685,32 +1656,9 @@ cdef group_min_max(
16851656
if val < group_min_or_max[lab, j]:
16861657
group_min_or_max[lab, j] = val
16871658

1688-
for i in range(ngroups):
1689-
for j in range(K):
1690-
if nobs[i, j] < min_count:
1691-
# if we are integer dtype, not is_datetimelike, and
1692-
# not uses_mask, then getting here implies that
1693-
# counts[i] < min_count, which means we will
1694-
# be cast to float64 and masked at the end
1695-
# of WrappedCythonOp._call_cython_op. So we can safely
1696-
# set a placeholder value in out[i, j].
1697-
if uses_mask:
1698-
result_mask[i, j] = True
1699-
# set out[i, j] to 0 to be deterministic, as
1700-
# it was initialized with np.empty. Also ensures
1701-
# we can downcast out if appropriate.
1702-
out[i, j] = 0
1703-
elif numeric_t is float32_t or numeric_t is float64_t:
1704-
out[i, j] = NAN
1705-
elif numeric_t is int64_t:
1706-
# Per above, this is a placeholder in
1707-
# non-is_datetimelike cases.
1708-
out[i, j] = NPY_NAT
1709-
else:
1710-
# placeholder, see above
1711-
out[i, j] = 0
1712-
else:
1713-
out[i, j] = group_min_or_max[i, j]
1659+
_check_below_mincount(
1660+
out, uses_mask, result_mask, ngroups, K, nobs, min_count, group_min_or_max
1661+
)
17141662

17151663

17161664
@cython.wraparound(False)

0 commit comments

Comments
 (0)