Skip to content

Commit 6c8b56a

Browse files
committed
ENH: cythonize groupby count
1 parent: e94e38a · commit: 6c8b56a

File tree

3 files changed

+561
-13
lines changed

3 files changed

+561
-13
lines changed

pandas/core/groupby.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import collections
66

77
from pandas.compat import(
8-
zip, builtins, range, long, lrange, lzip,
8+
zip, builtins, range, long, lzip,
99
OrderedDict, callable
1010
)
1111
from pandas import compat
@@ -713,15 +713,6 @@ def size(self):
713713
"""
714714
return self.grouper.size()
715715

716-
def count(self, axis=0):
717-
"""
718-
Number of non-null items in each group.
719-
axis : axis number, default 0
720-
the grouping axis
721-
"""
722-
self._set_selection_from_grouper()
723-
return self._python_agg_general(lambda x: notnull(x).sum(axis=axis)).astype('int64')
724-
725716
sum = _groupby_function('sum', 'add', np.sum)
726717
prod = _groupby_function('prod', 'prod', np.prod)
727718
min = _groupby_function('min', 'min', np.min, numeric_only=False)
@@ -731,6 +722,11 @@ def count(self, axis=0):
731722
last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
732723
_convert=True)
733724

725+
_count = _groupby_function('_count', 'count',
726+
lambda x, axis=0: notnull(x).sum(axis=axis))
727+
728+
def count(self, axis=0):
729+
return self._count().astype('int64')
734730

735731
def ohlc(self):
736732
"""
@@ -1318,10 +1314,11 @@ def get_group_levels(self):
13181314
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
13191315
},
13201316
'last': 'group_last',
1317+
'count': 'group_count',
13211318
}
13221319

13231320
_cython_transforms = {
1324-
'std': np.sqrt
1321+
'std': np.sqrt,
13251322
}
13261323

13271324
_cython_arity = {
@@ -1651,6 +1648,7 @@ def names(self):
16511648
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
16521649
},
16531650
'last': 'group_last_bin',
1651+
'count': 'group_count_bin',
16541652
}
16551653

16561654
_name_functions = {

pandas/src/generate_code.py

Lines changed: 88 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
# don't introduce a pandas/pandas.compat import
44
# or we get a bootstrapping problem
55
from StringIO import StringIO
6-
import os
76

87
header = """
98
cimport numpy as np
@@ -1150,6 +1149,86 @@ def group_var_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11501149
(ct * ct - ct))
11511150
"""
11521151

1152+
group_count_template = """@cython.boundscheck(False)
1153+
@cython.wraparound(False)
1154+
def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1155+
ndarray[int64_t] counts,
1156+
ndarray[%(c_type)s, ndim=2] values,
1157+
ndarray[int64_t] labels):
1158+
'''
1159+
Only aggregates on axis=0
1160+
'''
1161+
cdef:
1162+
Py_ssize_t i, j, N, K, lab
1163+
%(dest_type2)s val
1164+
ndarray[%(dest_type2)s, ndim=2] nobs = np.zeros_like(out)
1165+
1166+
1167+
if not len(values) == len(labels):
1168+
raise AssertionError("len(index) != len(labels)")
1169+
1170+
N, K = (<object> values).shape
1171+
1172+
for i in range(N):
1173+
lab = labels[i]
1174+
if lab < 0:
1175+
continue
1176+
1177+
counts[lab] += 1
1178+
for j in range(K):
1179+
val = values[i, j]
1180+
1181+
# not nan
1182+
nobs[lab, j] += val == val
1183+
1184+
for i in range(len(counts)):
1185+
for j in range(K):
1186+
out[i, j] = nobs[i, j]
1187+
1188+
1189+
"""
1190+
1191+
group_count_bin_template = """@cython.boundscheck(False)
1192+
@cython.wraparound(False)
1193+
def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1194+
ndarray[int64_t] counts,
1195+
ndarray[%(c_type)s, ndim=2] values,
1196+
ndarray[int64_t] bins):
1197+
'''
1198+
Only aggregates on axis=0
1199+
'''
1200+
cdef:
1201+
Py_ssize_t i, j, N, K, ngroups, b
1202+
%(dest_type2)s val, count
1203+
ndarray[%(dest_type2)s, ndim=2] nobs
1204+
1205+
nobs = np.zeros_like(out)
1206+
1207+
if bins[len(bins) - 1] == len(values):
1208+
ngroups = len(bins)
1209+
else:
1210+
ngroups = len(bins) + 1
1211+
1212+
N, K = (<object> values).shape
1213+
1214+
b = 0
1215+
for i in range(N):
1216+
while b < ngroups - 1 and i >= bins[b]:
1217+
b += 1
1218+
1219+
counts[b] += 1
1220+
for j in range(K):
1221+
val = values[i, j]
1222+
1223+
# not nan
1224+
nobs[b, j] += val == val
1225+
1226+
for i in range(ngroups):
1227+
for j in range(K):
1228+
out[i, j] = nobs[i, j]
1229+
1230+
1231+
"""
11531232
# add passing bin edges, instead of labels
11541233

11551234

@@ -2251,6 +2330,8 @@ def generate_from_template(template, exclude=None):
22512330
group_max_bin_template,
22522331
group_ohlc_template]
22532332

2333+
groupby_count = [group_count_template, group_count_bin_template]
2334+
22542335
templates_1d = [map_indices_template,
22552336
pad_template,
22562337
backfill_template,
@@ -2272,6 +2353,7 @@ def generate_from_template(template, exclude=None):
22722353
take_2d_axis1_template,
22732354
take_2d_multi_template]
22742355

2356+
22752357
def generate_take_cython_file(path='generated.pyx'):
22762358
with open(path, 'w') as f:
22772359
print(header, file=f)
@@ -2288,7 +2370,10 @@ def generate_take_cython_file(path='generated.pyx'):
22882370
print(generate_put_template(template), file=f)
22892371

22902372
for template in groupbys:
2291-
print(generate_put_template(template, use_ints = False), file=f)
2373+
print(generate_put_template(template, use_ints=False), file=f)
2374+
2375+
for template in groupby_count:
2376+
print(generate_put_template(template), file=f)
22922377

22932378
# for template in templates_1d_datetime:
22942379
# print >> f, generate_from_template_datetime(template)
@@ -2299,5 +2384,6 @@ def generate_take_cython_file(path='generated.pyx'):
22992384
for template in nobool_1d_templates:
23002385
print(generate_from_template(template, exclude=['bool']), file=f)
23012386

2387+
23022388
if __name__ == '__main__':
23032389
generate_take_cython_file()

0 commit comments

Comments
 (0)