Skip to content

Commit dbe8f7d

Browse files
committed
Grouping matrices are just scatter matrices
1 parent b53ebd9 commit dbe8f7d

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

src/sprog/aggregate.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
"""Pandas-aware aggregation functions."""
22

3-
import numpy as np
43
import pandas as pd
54
from pandas.api.typing import SeriesGroupBy
6-
from scipy import sparse
5+
6+
from sprog.sparse import scatter
77

88

99
def sum(groups: SeriesGroupBy) -> pd.Series: # noqa: A001
1010
"""Faster sum aggregates for LinearVariableArray series."""
1111
row = groups.ngroup().to_numpy()
1212
return groups._wrap_applied_output( # noqa: SLF001
1313
data=groups.obj,
14-
values=sparse.csr_array(
15-
(np.ones(row.size), (row, np.arange(row.size, dtype=row.dtype)))
16-
)
17-
@ groups.obj.array,
14+
values=scatter(row, grouping=True) @ groups.obj.array,
1815
not_indexed_same=True,
1916
is_transform=False,
2017
)

src/sprog/sparse.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
from scipy import sparse
1212

1313

14-
def scatter(indices: Sequence[Integral], m: int = -1, n: int = -1) -> sparse.csr_array:
14+
def scatter(
15+
indices: Sequence[Integral], *, m: int = -1, n: int = -1, grouping: bool = False
16+
) -> sparse.csr_array:
1517
"""`Scatter`_ consecutive indices of x into (larger) result vector y.
1618
1719
Args:
1820
indices: subset of range to populate (rest will be 0)
1921
m: length of range (defaults to :code:`max(indices) + 1`)
2022
n: length of domain (defaults to :code:`len(indices)`)
23+
grouping: allow :code:`m < n`, i.e. several indices sharing one row as in group-by sums, by skipping the :code:`m >= n` assertion (defaults to :code:`False`)
2124
2225
Returns:
2326
sparse array in CSR format
@@ -39,11 +42,13 @@ def scatter(indices: Sequence[Integral], m: int = -1, n: int = -1) -> sparse.csr
3942
n = k
4043
assert m >= max(indices) + 1
4144
assert n >= k
42-
assert m >= n
45+
assert grouping or m >= n
4346
return sparse.csr_array((np.ones(shape=k), (indices, range(k))), shape=(m, n))
4447

4548

46-
def gather(indices: Sequence[Integral], m: int = -1, n: int = -1) -> sparse.csr_array:
49+
def gather(
50+
indices: Sequence[Integral], *, m: int = -1, n: int = -1
51+
) -> sparse.csr_array:
4752
"""`Gather`_ subset of x into (smaller) consecutive result vector y.
4853
4954
Args:

0 commit comments

Comments
 (0)