|
41 | 41 | ") -> Tuple[np.ndarray, np.ndarray]:\n", |
42 | 42 | " \"\"\"Append each value of new to each group in data formed by indptr.\"\"\"\n", |
43 | 43 | " n_groups = len(indptr) - 1\n", |
44 | | - " rows = data.shape[0] + new.shape[0]\n", |
45 | | - " new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)\n", |
| 44 | + " n_rows = data.shape[0] + new.shape[0]\n", |
| 45 | + " if data.ndim == 2:\n", |
| 46 | + " new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))\n", |
| 47 | + " else:\n", |
| 48 | + " new_data = np.empty_like(data, shape=n_rows)\n", |
46 | 49 | " new_indptr = indptr.copy()\n", |
47 | 50 | " new_indptr[1:] += np.arange(1, n_groups + 1)\n", |
48 | 51 | " for i in range(n_groups):\n", |
|
61 | 64 | "outputs": [], |
62 | 65 | "source": [ |
63 | 66 | "# test _append_one\n", |
64 | | - "data = np.arange(5).reshape(-1, 1)\n", |
| 67 | + "data = np.arange(5)\n", |
65 | 68 | "indptr = np.array([0, 2, 5])\n", |
66 | 69 | "new = np.array([7, 8])\n", |
67 | 70 | "new_data, new_indptr = _append_one(data, indptr, new)\n", |
68 | 71 | "np.testing.assert_equal(\n", |
69 | 72 | " new_data,\n", |
70 | | - " np.array([0, 1, 7, 2, 3, 4, 8]).reshape(-1, 1),\n", |
| 73 | + " np.array([0, 1, 7, 2, 3, 4, 8])\n", |
| 74 | + ")\n", |
| 75 | + "np.testing.assert_equal(\n", |
| 76 | + " new_indptr,\n", |
| 77 | + " np.array([0, 3, 7]),\n", |
| 78 | + ")\n", |
| 79 | + "\n", |
| 80 | + "# 2d\n", |
| 81 | + "data = np.arange(5).reshape(-1, 1)\n", |
| 82 | + "new_data, new_indptr = _append_one(data, indptr, new)\n", |
| 83 | + "np.testing.assert_equal(\n", |
| 84 | + " new_data,\n", |
| 85 | + " np.array([0, 1, 7, 2, 3, 4, 8]).reshape(-1, 1)\n", |
71 | 86 | ")\n", |
72 | 87 | "np.testing.assert_equal(\n", |
73 | 88 | " new_indptr,\n", |
|
90 | 105 | " new_values: np.ndarray,\n", |
91 | 106 | " new_groups: np.ndarray,\n", |
92 | 107 | ") -> Tuple[np.ndarray, np.ndarray]:\n", |
93 | | - " rows = data.shape[0] + new_values.shape[0]\n", |
94 | | - " new_data = np.empty((rows, data.shape[1]), dtype=data.dtype)\n", |
95 | | - " new_indptr = np.empty(new_sizes.size + 1, dtype=indptr.dtype)\n", |
| 108 | + " n_rows = data.shape[0] + new_values.shape[0]\n", |
| 109 | + " if data.ndim == 2:\n", |
| 110 | + " new_data = np.empty_like(data, shape=(n_rows, data.shape[1]))\n", |
| 111 | + " else:\n", |
| 112 | + " new_data = np.empty_like(data, shape=n_rows)\n", |
| 113 | + " new_indptr = np.empty_like(indptr, shape=new_sizes.size + 1)\n", |
96 | 114 | " new_indptr[0] = 0\n", |
97 | 115 | " old_indptr_idx = 0\n", |
98 | 116 | " new_vals_idx = 0\n", |
|
122 | 140 | "outputs": [], |
123 | 141 | "source": [ |
124 | 142 | "# test append several\n", |
| 143 | + "data = np.arange(5)\n", |
| 144 | + "indptr = np.array([0, 2, 5])\n", |
| 145 | + "new_sizes = np.array([0, 2, 1])\n", |
| 146 | + "new_values = np.array([6, 7, 5])\n", |
| 147 | + "new_groups = np.array([False, True, False])\n", |
| 148 | + "new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups)\n", |
| 149 | + "np.testing.assert_equal(\n", |
| 150 | + " new_data,\n", |
| 151 | + " np.array([0, 1, 6, 7, 2, 3, 4, 5])\n", |
| 152 | + ")\n", |
| 153 | + "np.testing.assert_equal(\n", |
| 154 | + " new_indptr,\n", |
| 155 | + " np.array([0, 2, 4, 8]),\n", |
| 156 | + ")\n", |
| 157 | + "\n", |
| 158 | + "# 2d\n", |
125 | 159 | "data = np.arange(5).reshape(-1, 1)\n", |
126 | 160 | "indptr = np.array([0, 2, 5])\n", |
127 | 161 | "new_sizes = np.array([0, 2, 1])\n", |
|
130 | 164 | "new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups)\n", |
131 | 165 | "np.testing.assert_equal(\n", |
132 | 166 | " new_data,\n", |
133 | | - " np.array([0, 1, 6, 7, 2, 3, 4, 5]).reshape(-1, 1),\n", |
| 167 | + " np.array([0, 1, 6, 7, 2, 3, 4, 5]).reshape(-1, 1)\n", |
134 | 168 | ")\n", |
135 | 169 | "np.testing.assert_equal(\n", |
136 | 170 | " new_indptr,\n", |
|
172 | 206 | " data = data.astype(np.float32)\n", |
173 | 207 | " return cls(data, indptr)\n", |
174 | 208 | "\n", |
175 | | - " def _take_from_ranges(self, ranges: Sequence) -> 'GroupedArray':\n", |
| 209 | + " def _take_from_ranges(self, ranges: Sequence) -> Tuple[np.ndarray, np.ndarray]:\n", |
176 | 210 | " items = [self.data[r] for r in ranges]\n", |
177 | 211 | " sizes = np.array([item.shape[0] for item in items])\n", |
178 | | - " data = np.vstack(items)\n", |
| 212 | + " if self.data.ndim == 2:\n", |
| 213 | + " data = np.vstack(items)\n", |
| 214 | + " else:\n", |
| 215 | + " data = np.hstack(items)\n", |
179 | 216 | " indptr = np.append(0, sizes.cumsum())\n", |
180 | | - " return GroupedArray(data, indptr) \n", |
| 217 | + " return data, indptr\n", |
181 | 218 | "\n", |
182 | | - " def take(self, idxs: Sequence[int]) -> 'GroupedArray':\n", |
| 219 | + " def take(self, idxs: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:\n", |
183 | 220 | " \"\"\"Subset specific groups by their indices.\"\"\"\n", |
184 | 221 | " ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs]\n", |
185 | 222 | " return self._take_from_ranges(ranges)\n", |
186 | 223 | "\n", |
187 | | - " def take_from_groups(self, idx: Union[int, slice]) -> 'GroupedArray':\n", |
| 224 | + " def take_from_groups(self, idx: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:\n", |
188 | 225 | " \"\"\"Select a subset from each group.\"\"\"\n", |
189 | 226 | " if isinstance(idx, int):\n", |
190 | 227 | " # this preserves the 2d structure of data when indexing with the range\n", |
|
195 | 232 | " ]\n", |
196 | 233 | " return self._take_from_ranges(ranges)\n", |
197 | 234 | "\n", |
198 | | - " def append(self, new: np.ndarray) -> 'GroupedArray':\n", |
| 235 | + " def append(self, new: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n", |
199 | 236 | " \"\"\"Appends each element of `new` to each existing group. Returns a copy.\"\"\"\n", |
200 | 237 | " if new.shape[0] != self.n_groups:\n", |
201 | 238 | " raise ValueError(f\"new must have {self.n_groups} rows.\")\n", |
202 | | - " new_data, new_indptr = _append_one(self.data, self.indptr, new)\n", |
203 | | - " return GroupedArray(new_data, new_indptr)\n", |
| 239 | + " return _append_one(self.data, self.indptr, new)\n", |
204 | 240 | "\n", |
205 | 241 | " def append_several(\n", |
206 | 242 | " self, new_sizes: np.ndarray, new_values: np.ndarray, new_groups: np.ndarray\n", |
207 | | - " ) -> \"GroupedArray\":\n", |
208 | | - " new_data, new_indptr = _append_several(\n", |
| 243 | + " ) -> Tuple[np.ndarray, np.ndarray]:\n", |
| 244 | + " return _append_several(\n", |
209 | 245 | " self.data, self.indptr, new_sizes, new_values, new_groups\n", |
210 | 246 | " )\n", |
211 | | - " return GroupedArray(new_data, new_indptr)\n", |
212 | 247 | "\n", |
213 | 248 | " def __repr__(self):\n", |
214 | 249 | " return (\n", |
|
258 | 293 | { |
259 | 294 | "cell_type": "code", |
260 | 295 | "execution_count": null, |
261 | | - "id": "c3e2ea52-72c2-4b6a-aa17-deda48121c59", |
| 296 | + "id": "91761fea-19d7-4707-b4db-e74ba152010b", |
262 | 297 | "metadata": {}, |
263 | 298 | "outputs": [], |
264 | 299 | "source": [ |
265 | 300 | "# Take the last two observations from each group\n", |
266 | | - "last_2 = ga.take_from_groups(slice(-2, None))\n", |
| 301 | + "last2_data, last2_indptr = ga.take_from_groups(slice(-2, None))\n", |
267 | 302 | "np.testing.assert_equal(\n", |
268 | | - " last_2.data,\n", |
| 303 | + " last2_data,\n", |
269 | 304 | " np.vstack([\n", |
270 | 305 | " np.arange(4).reshape(-1, 2),\n", |
271 | 306 | " np.arange(16, 20).reshape(-1, 2),\n", |
272 | 307 | " ]),\n", |
273 | 308 | ")\n", |
274 | | - "np.testing.assert_equal(last_2.indptr, np.array([0, 2, 4]))\n", |
| 309 | + "np.testing.assert_equal(last2_indptr, np.array([0, 2, 4]))\n", |
275 | 310 | "\n", |
| 311 | + "# 1d\n", |
| 312 | + "ga1d = GroupedArray(np.arange(10), indptr)\n", |
| 313 | + "last2_data1d, last2_indptr1d = ga1d.take_from_groups(slice(-2, None))\n", |
| 314 | + "np.testing.assert_equal(\n", |
| 315 | + " last2_data1d,\n", |
| 316 | + " np.array([0, 1, 8, 9])\n", |
| 317 | + ")\n", |
| 318 | + "np.testing.assert_equal(last2_indptr1d, np.array([0, 2, 4]))" |
| 319 | + ] |
| 320 | + }, |
| 321 | + { |
| 322 | + "cell_type": "code", |
| 323 | + "execution_count": null, |
| 324 | + "id": "c3d635e1-9194-4547-8be9-2452b1f4f21e", |
| 325 | + "metadata": {}, |
| 326 | + "outputs": [], |
| 327 | + "source": [ |
276 | 328 | "# Take the second observation from each group\n", |
277 | | - "second = ga.take_from_groups(1)\n", |
278 | | - "np.testing.assert_equal(second.data, np.array([[2, 3], [6, 7]]))\n", |
279 | | - "np.testing.assert_equal(second.indptr, np.array([0, 1, 2]))" |
| 329 | + "second_data, second_indptr = ga.take_from_groups(1)\n", |
| 330 | + "np.testing.assert_equal(second_data, np.array([[2, 3], [6, 7]]))\n", |
| 331 | + "np.testing.assert_equal(second_indptr, np.array([0, 1, 2]))\n", |
| 332 | + "\n", |
| 333 | + "# 1d\n", |
| 334 | + "second_data1d, second_indptr1d = ga1d.take_from_groups(1)\n", |
| 335 | + "np.testing.assert_equal(second_data1d, np.array([1, 3]))\n", |
| 336 | + "np.testing.assert_equal(second_indptr1d, np.array([0, 1, 2]))" |
280 | 337 | ] |
281 | 338 | }, |
282 | 339 | { |
|
287 | 344 | "outputs": [], |
288 | 345 | "source": [ |
289 | 346 | "# Take the last four observations from every group. Note that since group 1 only has two elements, only these are returned.\n", |
290 | | - "last_4 = ga.take_from_groups(slice(-4, None))\n", |
| 347 | + "last4_data, last4_indptr = ga.take_from_groups(slice(-4, None))\n", |
291 | 348 | "np.testing.assert_equal(\n", |
292 | | - " last_4.data,\n", |
| 349 | + " last4_data,\n", |
293 | 350 | " np.vstack([\n", |
294 | 351 | " np.arange(4).reshape(-1, 2),\n", |
295 | 352 | " np.arange(12, 20).reshape(-1, 2),\n", |
296 | 353 | " ]),\n", |
297 | 354 | ")\n", |
298 | | - "np.testing.assert_equal(last_4.indptr, np.array([0, 2, 6]))" |
| 355 | + "np.testing.assert_equal(last4_indptr, np.array([0, 2, 6]))\n", |
| 356 | + "\n", |
| 357 | + "# 1d\n", |
| 358 | + "last4_data1d, last4_indptr1d = ga1d.take_from_groups(slice(-4, None))\n", |
| 359 | + "np.testing.assert_equal(\n", |
| 360 | + " last4_data1d,\n", |
| 361 | + " np.array([0, 1, 6, 7, 8, 9])\n", |
| 362 | + ")\n", |
| 363 | + "np.testing.assert_equal(last4_indptr1d, np.array([0, 2, 6]))" |
299 | 364 | ] |
300 | 365 | }, |
301 | 366 | { |
|
308 | 373 | "# Select a specific subset of groups\n", |
309 | 374 | "indptr = np.array([0, 2, 4, 7, 10])\n", |
310 | 375 | "ga2 = GroupedArray(data, indptr)\n", |
311 | | - "subset = ga2.take([0, 2])\n", |
| 376 | + "subset = GroupedArray(*ga2.take([0, 2]))\n", |
312 | 377 | "np.testing.assert_allclose(subset[0].data, ga2[0].data)\n", |
313 | | - "np.testing.assert_allclose(subset[1].data, ga2[2].data)" |
| 378 | + "np.testing.assert_allclose(subset[1].data, ga2[2].data)\n", |
| 379 | + "\n", |
| 380 | + "# 1d\n", |
| 381 | + "ga2_1d = GroupedArray(np.arange(10), indptr)\n", |
| 382 | + "subset1d = GroupedArray(*ga2_1d.take([0, 2]))\n", |
| 383 | + "np.testing.assert_allclose(subset1d[0].data, ga2_1d[0].data)\n", |
| 384 | + "np.testing.assert_allclose(subset1d[1].data, ga2_1d[2].data)" |
314 | 385 | ] |
315 | 386 | }, |
316 | 387 | { |
|
0 commit comments