Skip to content

Commit 5ee76eb

Browse files
authored
add more functions to processing (#27)
1 parent e629e54 commit 5ee76eb

File tree

3 files changed

+433
-34
lines changed

3 files changed

+433
-34
lines changed

nbs/processing.ipynb

Lines changed: 294 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"source": [
3232
"#| export\n",
3333
"import re\n",
34-
"from typing import Dict, List, Optional, Tuple, Union\n",
34+
"from typing import Any, Dict, List, Optional, Tuple, Union\n",
3535
"\n",
3636
"import numpy as np\n",
3737
"import pandas as pd\n",
@@ -358,6 +358,44 @@
358358
" )"
359359
]
360360
},
361+
{
362+
"cell_type": "code",
363+
"execution_count": null,
364+
"id": "757e7421-017d-49f9-bdd0-c59fc7556488",
365+
"metadata": {},
366+
"outputs": [],
367+
"source": [
368+
"#| export\n",
369+
"def match_if_categorical(s1: Union[Series, pd.Index], s2: Series) -> Tuple[Series, Series]:\n",
370+
" if isinstance(s1.dtype, pd.CategoricalDtype):\n",
371+
" if isinstance(s1, pd.Index):\n",
372+
" cat1 = s1.categories\n",
373+
" else:\n",
374+
" cat1 = s1.cat.categories\n",
375+
" if isinstance(s2.dtype, pd.CategoricalDtype):\n",
376+
" cat2 = s2.cat.categories\n",
377+
" else:\n",
378+
" cat2 = s2.unique().astype(cat1.dtype)\n",
379+
" missing = set(cat2) - set(cat1)\n",
380+
" if missing:\n",
381+
" # we assume the original is s1, so we extend its categories\n",
382+
" new_dtype = pd.CategoricalDtype(categories=cat1.tolist() + sorted(missing))\n",
383+
" s1 = s1.astype(new_dtype)\n",
384+
" s2 = s2.astype(new_dtype)\n",
385+
" elif isinstance(s1, pl_Series) and s1.dtype == pl.Categorical:\n",
386+
" with pl.StringCache():\n",
387+
" cat1 = s1.cat.get_categories()\n",
388+
" if s2.dtype == pl.Categorical:\n",
389+
" cat2 = s2.cat.get_categories()\n",
390+
" else:\n",
391+
" cat2 = s2.unique().sort().cast(cat1.dtype)\n",
392+
" # populate cache, keep original categories first\n",
393+
" pl.concat([cat1, cat2]).cast(pl.Categorical)\n",
394+
" s1 = s1.cast(pl.Utf8).cast(pl.Categorical)\n",
395+
" s2 = s2.cast(pl.Utf8).cast(pl.Categorical)\n",
396+
" return s1, s2"
397+
]
398+
},
361399
{
362400
"cell_type": "code",
363401
"execution_count": null,
@@ -369,15 +407,73 @@
369407
"def vertical_concat(dfs: List[DataFrame]) -> DataFrame:\n",
370408
" if not dfs:\n",
371409
" raise ValueError(\"Can't concatenate empty list.\")\n",
410+
" if len(dfs) == 1:\n",
411+
" return dfs\n",
372412
" if isinstance(dfs[0], pd.DataFrame):\n",
373-
" out = pd.concat(dfs)\n",
374-
" elif isinstance(dfs[0], pl_DataFrame):\n",
375-
" out = pl.concat(dfs)\n",
413+
" cat_cols = [c for c, dtype in dfs[0].dtypes.items() if isinstance(dtype, pd.CategoricalDtype)]\n",
414+
" if cat_cols:\n",
415+
" if len(dfs) > 2:\n",
416+
" raise NotImplementedError('Categorical replacement for more than two dataframes')\n",
417+
" assert len(dfs) == 2\n",
418+
" df1, df2 = dfs\n",
419+
" df1 = df1.copy(deep=False)\n",
420+
" df2 = df2.copy(deep=False) \n",
421+
" for col in cat_cols:\n",
422+
" s1, s2 = match_if_categorical(df1[col], df2[col])\n",
423+
" df1[col] = s1\n",
424+
" df2[col] = s2\n",
425+
" dfs = [df1, df2]\n",
426+
" out = pd.concat(dfs).reset_index(drop=True)\n",
376427
" else:\n",
377-
" raise ValueError(f'Got list of unexpected types: {type(dfs[0])}.')\n",
428+
" all_cols = dfs[0].columns\n",
429+
" cat_cols = [all_cols[i] for i, dtype in enumerate(dfs[0].dtypes) if dtype == pl.Categorical]\n",
430+
" if cat_cols:\n",
431+
" if len(dfs) > 2:\n",
432+
" raise NotImplementedError('Categorical replacement for more than two dataframes')\n",
433+
" assert len(dfs) == 2\n",
434+
" df1, df2 = dfs\n",
435+
" for col in cat_cols:\n",
436+
" s1, s2 = match_if_categorical(df1[col], df2[col])\n",
437+
" df1 = df1.with_columns(s1)\n",
438+
" df2 = df2.with_columns(s2)\n",
439+
" dfs = [df1, df2]\n",
440+
" out = pl.concat(dfs)\n",
378441
" return out"
379442
]
380443
},
444+
{
445+
"cell_type": "code",
446+
"execution_count": null,
447+
"id": "a21c0461-3964-4c82-a406-9fb7ea624f23",
448+
"metadata": {},
449+
"outputs": [],
450+
"source": [
451+
"df1 = pd.DataFrame({'x': ['a', 'b', 'c']}, dtype='category')\n",
452+
"df2 = pd.DataFrame({'x': ['f', 'b', 'a']}, dtype='category')\n",
453+
"pd.testing.assert_series_equal(\n",
454+
" vertical_concat([df1,df2])['x'],\n",
455+
" pd.Series(['a', 'b', 'c', 'f', 'b', 'a'], name='x', dtype=pd.CategoricalDtype(categories=['a', 'b', 'c', 'f']))\n",
456+
")"
457+
]
458+
},
459+
{
460+
"cell_type": "code",
461+
"execution_count": null,
462+
"id": "986ab374-90fc-4ba8-b442-797abc63d2de",
463+
"metadata": {},
464+
"outputs": [],
465+
"source": [
466+
"#| polars\n",
467+
"df1 = pl.DataFrame({'x': ['a', 'b', 'c']}, schema={'x': pl.Categorical})\n",
468+
"df2 = pl.DataFrame({'x': ['f', 'b', 'a']}, schema={'x': pl.Categorical})\n",
469+
"out = vertical_concat([df1,df2])['x']\n",
470+
"assert out.series_equal(pl.Series('x', ['a', 'b', 'c', 'f', 'b', 'a']))\n",
471+
"assert out.to_physical().series_equal(pl.Series('x', [0, 1, 2, 3, 1, 0]))\n",
472+
"assert out.cat.get_categories().series_equal(\n",
473+
" pl.Series('x', ['a', 'b', 'c', 'f'])\n",
474+
")"
475+
]
476+
},
381477
{
382478
"cell_type": "code",
383479
"execution_count": null,
@@ -452,11 +548,15 @@
452548
"source": [
453549
"#| export\n",
454550
"def join(\n",
455-
" df1: DataFrame,\n",
456-
" df2: DataFrame,\n",
551+
" df1: Union[DataFrame, Series],\n",
552+
" df2: Union[DataFrame, Series],\n",
457553
" on: Union[str, List[str]],\n",
458554
" how: str = 'inner'\n",
459555
") -> DataFrame:\n",
556+
" if isinstance(df1, (pd.Series, pl_Series)):\n",
557+
" df1 = df1.to_frame()\n",
558+
" if isinstance(df2, (pd.Series, pl_Series)):\n",
559+
" df2 = df2.to_frame()\n",
460560
" if isinstance(df1, pd.DataFrame):\n",
461561
" out = df1.merge(df2, on=on, how=how)\n",
462562
" else:\n",
@@ -502,14 +602,68 @@
502602
"outputs": [],
503603
"source": [
504604
"#| export\n",
505-
"def sort(df: DataFrame, by: Union[str, List[str]]) -> DataFrame:\n",
605+
"def sort(df: DataFrame, by: Optional[Union[str, List[str]]] = None) -> DataFrame:\n",
506606
" if isinstance(df, pd.DataFrame):\n",
507-
" out = df.sort_values(by)\n",
508-
" else:\n",
607+
" out = df.sort_values(by).reset_index(drop=True)\n",
608+
" elif isinstance(df, (pd.Series, pd.Index)):\n",
609+
" out = df.sort_values()\n",
610+
" if isinstance(out, pd.Series):\n",
611+
" out = out.reset_index(drop=True)\n",
612+
" elif isinstance(df, pl_DataFrame):\n",
509613
" out = df.sort(by)\n",
614+
" else:\n",
615+
" out = df.sort()\n",
510616
" return out"
511617
]
512618
},
619+
{
620+
"cell_type": "code",
621+
"execution_count": null,
622+
"id": "c14e0b1c-3770-4d8d-a8d0-63ed2bdf147c",
623+
"metadata": {},
624+
"outputs": [],
625+
"source": [
626+
"pd.testing.assert_frame_equal(\n",
627+
" sort(pd.DataFrame({'x': [3, 1, 2]}), 'x'),\n",
628+
" pd.DataFrame({'x': [1, 2, 3]})\n",
629+
")\n",
630+
"pd.testing.assert_frame_equal(\n",
631+
" sort(pd.DataFrame({'x': [3, 1, 2]}), ['x']),\n",
632+
" pd.DataFrame({'x': [1, 2, 3]})\n",
633+
")\n",
634+
"pd.testing.assert_series_equal(\n",
635+
" sort(pd.Series([3, 1, 2])),\n",
636+
" pd.Series([1, 2, 3])\n",
637+
")\n",
638+
"pd.testing.assert_index_equal(\n",
639+
" sort(pd.Index([3, 1, 2])),\n",
640+
" pd.Index([1, 2, 3])\n",
641+
")"
642+
]
643+
},
644+
{
645+
"cell_type": "code",
646+
"execution_count": null,
647+
"id": "43e1c151-1f81-442d-9f32-d88ca85a5e73",
648+
"metadata": {},
649+
"outputs": [],
650+
"source": [
651+
"#| polars\n",
652+
"# TODO: replace with pl.testing.assert_frame_equal when it's released\n",
653+
"pd.testing.assert_frame_equal(\n",
654+
" sort(pl.DataFrame({'x': [3, 1, 2]}), 'x').to_pandas(),\n",
655+
" pd.DataFrame({'x': [1, 2, 3]}),\n",
656+
")\n",
657+
"pd.testing.assert_frame_equal(\n",
658+
" sort(pl.DataFrame({'x': [3, 1, 2]}), ['x']).to_pandas(),\n",
659+
" pd.DataFrame({'x': [1, 2, 3]}),\n",
660+
")\n",
661+
"pd.testing.assert_series_equal(\n",
662+
" sort(pl.Series('x', [3, 1, 2])).to_pandas(),\n",
663+
" pd.Series([1, 2, 3], name='x')\n",
664+
")"
665+
]
666+
},
513667
{
514668
"cell_type": "code",
515669
"execution_count": null,
@@ -557,6 +711,49 @@
557711
" return out"
558712
]
559713
},
714+
{
715+
"cell_type": "code",
716+
"execution_count": null,
717+
"id": "a2e3ff2c-9e70-46b3-9bf0-bbfcd339d9ba",
718+
"metadata": {},
719+
"outputs": [],
720+
"source": [
721+
"#| export\n",
722+
"def group_by_agg(df: DataFrame, by, aggs, maintain_order=False) -> DataFrame:\n",
723+
" if isinstance(df, pd.DataFrame):\n",
724+
" out = group_by(df, by, maintain_order).agg(aggs).reset_index()\n",
725+
" else:\n",
726+
" out = group_by(df, by, maintain_order).agg(*[getattr(pl.col(c), agg)() for c, agg in aggs.items()])\n",
727+
" return out"
728+
]
729+
},
730+
{
731+
"cell_type": "code",
732+
"execution_count": null,
733+
"id": "d9f92cd4-d3f7-4de1-b438-c3c5891c3343",
734+
"metadata": {},
735+
"outputs": [],
736+
"source": [
737+
"pd.testing.assert_frame_equal(\n",
738+
" group_by_agg(pd.DataFrame({'x': [1, 1, 2], 'y': [1, 1, 1]}), 'x', {'y': 'sum'}),\n",
739+
" pd.DataFrame({'x': [1, 2], 'y': [2, 1]})\n",
740+
")"
741+
]
742+
},
743+
{
744+
"cell_type": "code",
745+
"execution_count": null,
746+
"id": "329cfc66-a218-498e-b674-96491f47a3e1",
747+
"metadata": {},
748+
"outputs": [],
749+
"source": [
750+
"#| polars\n",
751+
"pd.testing.assert_frame_equal(\n",
752+
" group_by_agg(pl.DataFrame({'x': [1, 1, 2], 'y': [1, 1, 1]}), 'x', {'y': 'sum'}, maintain_order=True).to_pandas(),\n",
753+
" pd.DataFrame({'x': [1, 2], 'y': [2, 1]})\n",
754+
")"
755+
]
756+
},
560757
{
561758
"cell_type": "code",
562759
"execution_count": null,
@@ -594,6 +791,93 @@
594791
"np.testing.assert_equal(is_in(pl.Series([1, 2, 3]), [1]), np.array([True, False, False]))"
595792
]
596793
},
794+
{
795+
"cell_type": "code",
796+
"execution_count": null,
797+
"id": "9717022e-2c6f-47dc-8b19-da069341b094",
798+
"metadata": {},
799+
"outputs": [],
800+
"source": [
801+
"#| export\n",
802+
"def between(s: Series, lower: Series, upper: Series) -> Series:\n",
803+
" if isinstance(s, pd.Series):\n",
804+
" out = s.between(lower, upper)\n",
805+
" else:\n",
806+
" out = s.is_between(lower, upper)\n",
807+
" return out"
808+
]
809+
},
810+
{
811+
"cell_type": "code",
812+
"execution_count": null,
813+
"id": "dca138b4-e771-4b8e-aa54-35dc37802d78",
814+
"metadata": {},
815+
"outputs": [],
816+
"source": [
817+
"np.testing.assert_equal(\n",
818+
" between(pd.Series([1, 2, 3]), pd.Series([0, 1, 4]), pd.Series([4, 1, 2])),\n",
819+
" np.array([True, False, False]),\n",
820+
")"
821+
]
822+
},
823+
{
824+
"cell_type": "code",
825+
"execution_count": null,
826+
"id": "c6c773bf-fe23-4428-84f6-c5afaefdad06",
827+
"metadata": {},
828+
"outputs": [],
829+
"source": [
830+
"#| polars\n",
831+
"np.testing.assert_equal(\n",
832+
" between(pl.Series([1, 2, 3]), pl.Series([0, 1, 4]), pl.Series([4, 1, 2])),\n",
833+
" np.array([True, False, False]),\n",
834+
")"
835+
]
836+
},
837+
{
838+
"cell_type": "code",
839+
"execution_count": null,
840+
"id": "667302bf-3b54-4298-8fcc-82cd6b12fb73",
841+
"metadata": {},
842+
"outputs": [],
843+
"source": [
844+
"#| export\n",
845+
"def fill_null(df: DataFrame, mapping: Dict[str, Any]) -> DataFrame:\n",
846+
" if isinstance(df, pd.DataFrame):\n",
847+
" out = df.fillna(mapping)\n",
848+
" else:\n",
849+
" out = df.with_columns(*[pl.col(col).fill_null(v) for col, v in mapping.items()])\n",
850+
" return out"
851+
]
852+
},
853+
{
854+
"cell_type": "code",
855+
"execution_count": null,
856+
"id": "74993c58-0886-4290-ab90-8065651886c5",
857+
"metadata": {},
858+
"outputs": [],
859+
"source": [
860+
"pd.testing.assert_frame_equal(\n",
861+
" fill_null(pd.DataFrame({'x': [1, np.nan], 'y': [np.nan, 2]}), {'x': 2, 'y': 1}),\n",
862+
" pd.DataFrame({'x': [1, 2], 'y': [1, 2]}, dtype='float64')\n",
863+
")"
864+
]
865+
},
866+
{
867+
"cell_type": "code",
868+
"execution_count": null,
869+
"id": "ec1d835a-f1dc-4c1a-be2d-e7dd5b9895ad",
870+
"metadata": {},
871+
"outputs": [],
872+
"source": [
873+
"#| polars\n",
874+
"# TODO: replace with pl.testing.assert_frame_equal when it's released\n",
875+
"pd.testing.assert_frame_equal(\n",
876+
" fill_null(pl.DataFrame({'x': [1, None], 'y': [None, 2]}), {'x': 2, 'y': 1}).to_pandas(),\n",
877+
" pd.DataFrame({'x': [1, 2], 'y': [1, 2]})\n",
878+
")"
879+
]
880+
},
597881
{
598882
"cell_type": "code",
599883
"execution_count": null,

utilsforecast/_modidx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,20 @@
6868
'utilsforecast/processing.py'),
6969
'utilsforecast.processing.assign_columns': ( 'processing.html#assign_columns',
7070
'utilsforecast/processing.py'),
71+
'utilsforecast.processing.between': ('processing.html#between', 'utilsforecast/processing.py'),
7172
'utilsforecast.processing.copy_if_pandas': ( 'processing.html#copy_if_pandas',
7273
'utilsforecast/processing.py'),
7374
'utilsforecast.processing.counts_by_id': ( 'processing.html#counts_by_id',
7475
'utilsforecast/processing.py'),
7576
'utilsforecast.processing.drop_index_if_pandas': ( 'processing.html#drop_index_if_pandas',
7677
'utilsforecast/processing.py'),
78+
'utilsforecast.processing.fill_null': ( 'processing.html#fill_null',
79+
'utilsforecast/processing.py'),
7780
'utilsforecast.processing.filter_with_mask': ( 'processing.html#filter_with_mask',
7881
'utilsforecast/processing.py'),
7982
'utilsforecast.processing.group_by': ('processing.html#group_by', 'utilsforecast/processing.py'),
83+
'utilsforecast.processing.group_by_agg': ( 'processing.html#group_by_agg',
84+
'utilsforecast/processing.py'),
8085
'utilsforecast.processing.horizontal_concat': ( 'processing.html#horizontal_concat',
8186
'utilsforecast/processing.py'),
8287
'utilsforecast.processing.is_in': ('processing.html#is_in', 'utilsforecast/processing.py'),
@@ -85,6 +90,8 @@
8590
'utilsforecast/processing.py'),
8691
'utilsforecast.processing.is_none': ('processing.html#is_none', 'utilsforecast/processing.py'),
8792
'utilsforecast.processing.join': ('processing.html#join', 'utilsforecast/processing.py'),
93+
'utilsforecast.processing.match_if_categorical': ( 'processing.html#match_if_categorical',
94+
'utilsforecast/processing.py'),
8895
'utilsforecast.processing.maybe_compute_sort_indices': ( 'processing.html#maybe_compute_sort_indices',
8996
'utilsforecast/processing.py'),
9097
'utilsforecast.processing.offset_dates': ( 'processing.html#offset_dates',

0 commit comments

Comments
 (0)