Skip to content

Commit 74ed1d4

Browse files
authored
grouped array fixes (#35)
1 parent 20c1dde commit 74ed1d4

File tree

2 files changed

+125
-47
lines changed

2 files changed

+125
-47
lines changed

nbs/grouped_array.ipynb

Lines changed: 102 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@
4141
") -> Tuple[np.ndarray, np.ndarray]:\n",
4242
" \"\"\"Append each value of new to each group in data formed by indptr.\"\"\"\n",
4343
" n_groups = len(indptr) - 1\n",
44-
" rows = data.shape[0] + new.shape[0]\n",
45-
" new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)\n",
44+
" n_rows = data.shape[0] + new.shape[0]\n",
45+
" if data.ndim == 2:\n",
46+
" new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))\n",
47+
" else:\n",
48+
" new_data = np.empty_like(data, shape=n_rows)\n",
4649
" new_indptr = indptr.copy()\n",
4750
" new_indptr[1:] += np.arange(1, n_groups + 1)\n",
4851
" for i in range(n_groups):\n",
@@ -61,13 +64,25 @@
6164
"outputs": [],
6265
"source": [
6366
"# test _append_one\n",
64-
"data = np.arange(5).reshape(-1, 1)\n",
67+
"data = np.arange(5)\n",
6568
"indptr = np.array([0, 2, 5])\n",
6669
"new = np.array([7, 8])\n",
6770
"new_data, new_indptr = _append_one(data, indptr, new)\n",
6871
"np.testing.assert_equal(\n",
6972
" new_data,\n",
70-
" np.array([0, 1, 7, 2, 3, 4, 8]).reshape(-1, 1),\n",
73+
" np.array([0, 1, 7, 2, 3, 4, 8])\n",
74+
")\n",
75+
"np.testing.assert_equal(\n",
76+
" new_indptr,\n",
77+
" np.array([0, 3, 7]),\n",
78+
")\n",
79+
"\n",
80+
"# 2d\n",
81+
"data = np.arange(5).reshape(-1, 1)\n",
82+
"new_data, new_indptr = _append_one(data, indptr, new)\n",
83+
"np.testing.assert_equal(\n",
84+
" new_data,\n",
85+
" np.array([0, 1, 7, 2, 3, 4, 8]).reshape(-1, 1)\n",
7186
")\n",
7287
"np.testing.assert_equal(\n",
7388
" new_indptr,\n",
@@ -90,9 +105,12 @@
90105
" new_values: np.ndarray,\n",
91106
" new_groups: np.ndarray,\n",
92107
") -> Tuple[np.ndarray, np.ndarray]:\n",
93-
" rows = data.shape[0] + new_values.shape[0]\n",
94-
" new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)\n",
95-
" new_indptr = np.empty(new_sizes.size + 1, dtype=indptr.dtype)\n",
108+
" n_rows = data.shape[0] + new_values.shape[0]\n",
109+
" if data.ndim == 2:\n",
110+
" new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))\n",
111+
" else:\n",
112+
" new_data = np.empty_like(data, shape=n_rows)\n",
113+
" new_indptr = np.empty_like(indptr, shape=new_sizes.size + 1)\n",
96114
" new_indptr[0] = 0\n",
97115
" old_indptr_idx = 0\n",
98116
" new_vals_idx = 0\n",
@@ -122,6 +140,22 @@
122140
"outputs": [],
123141
"source": [
124142
"# test append several\n",
143+
"data = np.arange(5)\n",
144+
"indptr = np.array([0, 2, 5])\n",
145+
"new_sizes = np.array([0, 2, 1])\n",
146+
"new_values = np.array([6, 7, 5])\n",
147+
"new_groups = np.array([False, True, False])\n",
148+
"new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups)\n",
149+
"np.testing.assert_equal(\n",
150+
" new_data,\n",
151+
" np.array([0, 1, 6, 7, 2, 3, 4, 5])\n",
152+
")\n",
153+
"np.testing.assert_equal(\n",
154+
" new_indptr,\n",
155+
" np.array([0, 2, 4, 8]),\n",
156+
")\n",
157+
"\n",
158+
"# 2d\n",
125159
"data = np.arange(5).reshape(-1, 1)\n",
126160
"indptr = np.array([0, 2, 5])\n",
127161
"new_sizes = np.array([0, 2, 1])\n",
@@ -130,7 +164,7 @@
130164
"new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups)\n",
131165
"np.testing.assert_equal(\n",
132166
" new_data,\n",
133-
" np.array([0, 1, 6, 7, 2, 3, 4, 5]).reshape(-1, 1),\n",
167+
" np.array([0, 1, 6, 7, 2, 3, 4, 5]).reshape(-1, 1)\n",
134168
")\n",
135169
"np.testing.assert_equal(\n",
136170
" new_indptr,\n",
@@ -172,19 +206,22 @@
172206
" data = data.astype(np.float32)\n",
173207
" return cls(data, indptr)\n",
174208
"\n",
175-
" def _take_from_ranges(self, ranges: Sequence) -> 'GroupedArray':\n",
209+
" def _take_from_ranges(self, ranges: Sequence) -> Tuple[np.ndarray, np.ndarray]:\n",
176210
" items = [self.data[r] for r in ranges]\n",
177211
" sizes = np.array([item.shape[0] for item in items])\n",
178-
" data = np.vstack(items)\n",
212+
" if self.data.ndim == 2:\n",
213+
" data = np.vstack(items)\n",
214+
" else:\n",
215+
" data = np.hstack(items)\n",
179216
" indptr = np.append(0, sizes.cumsum())\n",
180-
" return GroupedArray(data, indptr) \n",
217+
" return data, indptr\n",
181218
"\n",
182-
" def take(self, idxs: Sequence[int]) -> 'GroupedArray':\n",
219+
" def take(self, idxs: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:\n",
183220
" \"\"\"Subset specific groups by their indices.\"\"\"\n",
184221
" ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs]\n",
185222
" return self._take_from_ranges(ranges)\n",
186223
"\n",
187-
" def take_from_groups(self, idx: Union[int, slice]) -> 'GroupedArray':\n",
224+
" def take_from_groups(self, idx: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:\n",
188225
" \"\"\"Select a subset from each group.\"\"\"\n",
189226
" if isinstance(idx, int):\n",
190227
" # this preserves the 2d structure of data when indexing with the range\n",
@@ -195,20 +232,18 @@
195232
" ]\n",
196233
" return self._take_from_ranges(ranges)\n",
197234
"\n",
198-
" def append(self, new: np.ndarray) -> 'GroupedArray':\n",
235+
" def append(self, new: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n",
199236
" \"\"\"Appends each element of `new` to each existing group. Returns a copy.\"\"\"\n",
200237
" if new.shape[0] != self.n_groups:\n",
201238
" raise ValueError(f\"new must have {self.n_groups} rows.\")\n",
202-
" new_data, new_indptr = _append_one(self.data, self.indptr, new)\n",
203-
" return GroupedArray(new_data, new_indptr)\n",
239+
" return _append_one(self.data, self.indptr, new)\n",
204240
"\n",
205241
" def append_several(\n",
206242
" self, new_sizes: np.ndarray, new_values: np.ndarray, new_groups: np.ndarray\n",
207-
" ) -> \"GroupedArray\":\n",
208-
" new_data, new_indptr = _append_several(\n",
243+
" ) -> Tuple[np.ndarray, np.ndarray]:\n",
244+
" return _append_several(\n",
209245
" self.data, self.indptr, new_sizes, new_values, new_groups\n",
210246
" )\n",
211-
" return GroupedArray(new_data, new_indptr)\n",
212247
"\n",
213248
" def __repr__(self):\n",
214249
" return (\n",
@@ -258,25 +293,47 @@
258293
{
259294
"cell_type": "code",
260295
"execution_count": null,
261-
"id": "c3e2ea52-72c2-4b6a-aa17-deda48121c59",
296+
"id": "91761fea-19d7-4707-b4db-e74ba152010b",
262297
"metadata": {},
263298
"outputs": [],
264299
"source": [
265300
"# Take the last two observations from each group\n",
266-
"last_2 = ga.take_from_groups(slice(-2, None))\n",
301+
"last2_data, last2_indptr = ga.take_from_groups(slice(-2, None))\n",
267302
"np.testing.assert_equal(\n",
268-
" last_2.data,\n",
303+
" last2_data,\n",
269304
" np.vstack([\n",
270305
" np.arange(4).reshape(-1, 2),\n",
271306
" np.arange(16, 20).reshape(-1, 2),\n",
272307
" ]),\n",
273308
")\n",
274-
"np.testing.assert_equal(last_2.indptr, np.array([0, 2, 4]))\n",
309+
"np.testing.assert_equal(last2_indptr, np.array([0, 2, 4]))\n",
275310
"\n",
311+
"# 1d\n",
312+
"ga1d = GroupedArray(np.arange(10), indptr)\n",
313+
"last2_data1d, last2_indptr1d = ga1d.take_from_groups(slice(-2, None))\n",
314+
"np.testing.assert_equal(\n",
315+
" last2_data1d,\n",
316+
" np.array([0, 1, 8, 9])\n",
317+
")\n",
318+
"np.testing.assert_equal(last2_indptr1d, np.array([0, 2, 4]))"
319+
]
320+
},
321+
{
322+
"cell_type": "code",
323+
"execution_count": null,
324+
"id": "c3d635e1-9194-4547-8be9-2452b1f4f21e",
325+
"metadata": {},
326+
"outputs": [],
327+
"source": [
276328
"# Take the second observation from each group\n",
277-
"second = ga.take_from_groups(1)\n",
278-
"np.testing.assert_equal(second.data, np.array([[2, 3], [6, 7]]))\n",
279-
"np.testing.assert_equal(second.indptr, np.array([0, 1, 2]))"
329+
"second_data, second_indptr = ga.take_from_groups(1)\n",
330+
"np.testing.assert_equal(second_data, np.array([[2, 3], [6, 7]]))\n",
331+
"np.testing.assert_equal(second_indptr, np.array([0, 1, 2]))\n",
332+
"\n",
333+
"# 1d\n",
334+
"second_data1d, second_indptr1d = ga1d.take_from_groups(1)\n",
335+
"np.testing.assert_equal(second_data1d, np.array([1, 3]))\n",
336+
"np.testing.assert_equal(second_indptr1d, np.array([0, 1, 2]))"
280337
]
281338
},
282339
{
@@ -287,15 +344,23 @@
287344
"outputs": [],
288345
"source": [
289346
"# Take the last four observations from every group. Note that since group 1 only has two elements, only these are returned.\n",
290-
"last_4 = ga.take_from_groups(slice(-4, None))\n",
347+
"last4_data, last4_indptr = ga.take_from_groups(slice(-4, None))\n",
291348
"np.testing.assert_equal(\n",
292-
" last_4.data,\n",
349+
" last4_data,\n",
293350
" np.vstack([\n",
294351
" np.arange(4).reshape(-1, 2),\n",
295352
" np.arange(12, 20).reshape(-1, 2),\n",
296353
" ]),\n",
297354
")\n",
298-
"np.testing.assert_equal(last_4.indptr, np.array([0, 2, 6]))"
355+
"np.testing.assert_equal(last4_indptr, np.array([0, 2, 6]))\n",
356+
"\n",
357+
"# 1d\n",
358+
"last4_data1d, last4_indptr1d = ga1d.take_from_groups(slice(-4, None))\n",
359+
"np.testing.assert_equal(\n",
360+
" last4_data1d,\n",
361+
" np.array([0, 1, 6, 7, 8, 9])\n",
362+
")\n",
363+
"np.testing.assert_equal(last4_indptr1d, np.array([0, 2, 6]))"
299364
]
300365
},
301366
{
@@ -308,9 +373,15 @@
308373
"# Select a specific subset of groups\n",
309374
"indptr = np.array([0, 2, 4, 7, 10])\n",
310375
"ga2 = GroupedArray(data, indptr)\n",
311-
"subset = ga2.take([0, 2])\n",
376+
"subset = GroupedArray(*ga2.take([0, 2]))\n",
312377
"np.testing.assert_allclose(subset[0].data, ga2[0].data)\n",
313-
"np.testing.assert_allclose(subset[1].data, ga2[2].data)"
378+
"np.testing.assert_allclose(subset[1].data, ga2[2].data)\n",
379+
"\n",
380+
"# 1d\n",
381+
"ga2_1d = GroupedArray(np.arange(10), indptr)\n",
382+
"subset1d = GroupedArray(*ga2_1d.take([0, 2]))\n",
383+
"np.testing.assert_allclose(subset1d[0].data, ga2_1d[0].data)\n",
384+
"np.testing.assert_allclose(subset1d[1].data, ga2_1d[2].data)"
314385
]
315386
},
316387
{

utilsforecast/grouped_array.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ def _append_one(
1717
) -> Tuple[np.ndarray, np.ndarray]:
1818
"""Append each value of new to each group in data formed by indptr."""
1919
n_groups = len(indptr) - 1
20-
rows = data.shape[0] + new.shape[0]
21-
new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)
20+
n_rows = data.shape[0] + new.shape[0]
21+
if data.ndim == 2:
22+
new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))
23+
else:
24+
new_data = np.empty_like(data, shape=n_rows)
2225
new_indptr = indptr.copy()
2326
new_indptr[1:] += np.arange(1, n_groups + 1)
2427
for i in range(n_groups):
@@ -36,9 +39,12 @@ def _append_several(
3639
new_values: np.ndarray,
3740
new_groups: np.ndarray,
3841
) -> Tuple[np.ndarray, np.ndarray]:
39-
rows = data.shape[0] + new_values.shape[0]
40-
new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)
41-
new_indptr = np.empty(new_sizes.size + 1, dtype=indptr.dtype)
42+
n_rows = data.shape[0] + new_values.shape[0]
43+
if data.ndim == 2:
44+
new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))
45+
else:
46+
new_data = np.empty_like(data, shape=n_rows)
47+
new_indptr = np.empty_like(indptr, shape=new_sizes.size + 1)
4248
new_indptr[0] = 0
4349
old_indptr_idx = 0
4450
new_vals_idx = 0
@@ -86,19 +92,22 @@ def from_sorted_df(
8692
data = data.astype(np.float32)
8793
return cls(data, indptr)
8894

89-
def _take_from_ranges(self, ranges: Sequence) -> "GroupedArray":
95+
def _take_from_ranges(self, ranges: Sequence) -> Tuple[np.ndarray, np.ndarray]:
9096
items = [self.data[r] for r in ranges]
9197
sizes = np.array([item.shape[0] for item in items])
92-
data = np.vstack(items)
98+
if self.data.ndim == 2:
99+
data = np.vstack(items)
100+
else:
101+
data = np.hstack(items)
93102
indptr = np.append(0, sizes.cumsum())
94-
return GroupedArray(data, indptr)
103+
return data, indptr
95104

96-
def take(self, idxs: Sequence[int]) -> "GroupedArray":
105+
def take(self, idxs: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:
97106
"""Subset specific groups by their indices."""
98107
ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs]
99108
return self._take_from_ranges(ranges)
100109

101-
def take_from_groups(self, idx: Union[int, slice]) -> "GroupedArray":
110+
def take_from_groups(self, idx: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
102111
"""Select a subset from each group."""
103112
if isinstance(idx, int):
104113
# this preserves the 2d structure of data when indexing with the range
@@ -108,20 +117,18 @@ def take_from_groups(self, idx: Union[int, slice]) -> "GroupedArray":
108117
]
109118
return self._take_from_ranges(ranges)
110119

111-
def append(self, new: np.ndarray) -> "GroupedArray":
120+
def append(self, new: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
112121
"""Appends each element of `new` to each existing group. Returns a copy."""
113122
if new.shape[0] != self.n_groups:
114123
raise ValueError(f"new must have {self.n_groups} rows.")
115-
new_data, new_indptr = _append_one(self.data, self.indptr, new)
116-
return GroupedArray(new_data, new_indptr)
124+
return _append_one(self.data, self.indptr, new)
117125

118126
def append_several(
119127
self, new_sizes: np.ndarray, new_values: np.ndarray, new_groups: np.ndarray
120-
) -> "GroupedArray":
121-
new_data, new_indptr = _append_several(
128+
) -> Tuple[np.ndarray, np.ndarray]:
129+
return _append_several(
122130
self.data, self.indptr, new_sizes, new_values, new_groups
123131
)
124-
return GroupedArray(new_data, new_indptr)
125132

126133
def __repr__(self):
127134
return f"{self.__class__.__name__}(n_rows={self.data.shape[0]:,}, n_groups={self.n_groups:,})"

0 commit comments

Comments
 (0)