|
24 | 24 | "import pandas as pd\n", |
25 | 25 | "\n", |
26 | 26 | "from utilsforecast.compat import DataFrame\n", |
27 | | - "from utilsforecast.processing import DataFrameProcessor" |
| 27 | + "from utilsforecast.processing import DataFrameProcessor, group_by" |
28 | 28 | ] |
29 | 29 | }, |
30 | 30 | { |
|
168 | 168 | " if isinstance(df, pd.DataFrame):\n", |
169 | 169 | " sizes = df.groupby(id_col, observed=True).size().values\n", |
170 | 170 | " else:\n", |
171 | | - " try:\n", |
172 | | - " group_sizes = df.group_by(id_col, maintain_order=True).count()\n", |
173 | | - " except AttributeError:\n", |
174 | | - " group_sizes = df.groupby(id_col, maintain_order=True).count()\n", |
| 171 | + " group_sizes = group_by(df, id_col, maintain_order=True).count()\n", |
175 | 172 | " sizes = group_sizes['count'].to_numpy()\n", |
176 | 173 | " \n", |
177 | 174 | " indptr = np.append(0, sizes.cumsum())\n", |
178 | 175 | " proc = DataFrameProcessor(id_col, time_col, target_col)\n", |
179 | | - " data = proc._value_cols_to_numpy(df)\n", |
| 176 | + " data = proc.value_cols_to_numpy(df)\n", |
180 | 177 | " if data.dtype not in (np.float32, np.float64):\n", |
181 | 178 | " data = data.astype(np.float32)\n", |
182 | 179 | " return cls(data, indptr)\n", |
|
232 | 229 | "metadata": {}, |
233 | 230 | "outputs": [], |
234 | 231 | "source": [ |
235 | | - "from fastcore.test import test_eq, test_fail" |
| 232 | + "from fastcore.test import test_eq, test_fail\n", |
| 233 | + "\n", |
| 234 | + "from utilsforecast.data import generate_series" |
236 | 235 | ] |
237 | 236 | }, |
238 | 237 | { |
|
334 | 333 | { |
335 | 334 | "cell_type": "code", |
336 | 335 | "execution_count": null, |
337 | | - "id": "8c28ccf5-0f75-4cae-a0d5-e880037ff3e1", |
338 | | - "metadata": {}, |
339 | | - "outputs": [], |
340 | | - "source": [ |
341 | | - "from utilsforecast.data import generate_series" |
342 | | - ] |
343 | | - }, |
344 | | - { |
345 | | - "cell_type": "code", |
346 | | - "execution_count": null, |
347 | | - "id": "6914dcf3-4d13-42f7-917c-242a47477740", |
| 336 | + "id": "c06325e0-8265-4a61-b936-7fc29d6396be", |
348 | 337 | "metadata": {}, |
349 | 338 | "outputs": [], |
350 | 339 | "source": [ |
| 340 | + "#| polars\n", |
351 | 341 | "# build from df\n", |
352 | 342 | "series_pd = generate_series(10, static_as_categorical=False, engine='pandas')\n", |
353 | | - "series_pl = generate_series(10, static_as_categorical=False, engine='polars')\n", |
354 | 343 | "ga_pd = GroupedArray.from_sorted_df(series_pd, 'unique_id', 'ds', 'y')\n", |
| 344 | + "series_pl = generate_series(10, static_as_categorical=False, engine='polars')\n", |
355 | 345 | "ga_pl = GroupedArray.from_sorted_df(series_pl, 'unique_id', 'ds', 'y')\n", |
356 | 346 | "np.testing.assert_allclose(ga_pd.data, ga_pl.data)\n", |
357 | 347 | "np.testing.assert_equal(ga_pd.indptr, ga_pl.indptr)" |
|
0 commit comments