|
31 | 31 | "source": [ |
32 | 32 | "#| export\n", |
33 | 33 | "import re\n", |
34 | | - "from typing import Dict, List, Optional, Tuple, Union\n", |
| 34 | + "from typing import Any, Dict, List, Optional, Tuple, Union\n", |
35 | 35 | "\n", |
36 | 36 | "import numpy as np\n", |
37 | 37 | "import pandas as pd\n", |
|
358 | 358 | " )" |
359 | 359 | ] |
360 | 360 | }, |
| 361 | + { |
| 362 | + "cell_type": "code", |
| 363 | + "execution_count": null, |
| 364 | + "id": "757e7421-017d-49f9-bdd0-c59fc7556488", |
| 365 | + "metadata": {}, |
| 366 | + "outputs": [], |
| 367 | + "source": [ |
| 368 | + "#| export\n", |
def match_if_categorical(s1: Union[Series, pd.Index], s2: Series) -> Tuple[Series, Series]:
    """Give `s1` and `s2` a shared categorical encoding when `s1` is categorical.

    Parameters
    ----------
    s1 : pandas Series, pandas Index or polars Series
        Reference values. Its existing categories keep their positions.
    s2 : pandas or polars Series
        Values to align with `s1`. May be categorical or not.

    Returns
    -------
    tuple
        `(s1, s2)`. If `s1` is not categorical both are returned unchanged.
        For pandas, both are only re-typed when `s2` has values that are not
        already categories of `s1`; the new categories are appended, sorted,
        after the original ones. For polars, both are re-encoded as
        Categorical inside a shared string cache.
    """
    if isinstance(s1.dtype, pd.CategoricalDtype):
        if isinstance(s1, pd.Index):
            cat1 = s1.categories
        else:
            cat1 = s1.cat.categories
        if isinstance(s2.dtype, pd.CategoricalDtype):
            cat2 = s2.cat.categories
        else:
            # nulls can't be categories (CategoricalDtype rejects them and they
            # break the sort below), so they are dropped before collecting values
            cat2 = s2.dropna().unique().astype(cat1.dtype)
        missing = set(cat2) - set(cat1)
        if missing:
            # we assume the original is s1, so we extend its categories
            new_dtype = pd.CategoricalDtype(categories=cat1.tolist() + sorted(missing))
            s1 = s1.astype(new_dtype)
            s2 = s2.astype(new_dtype)
    elif isinstance(s1, pl_Series) and s1.dtype == pl.Categorical:
        with pl.StringCache():
            cat1 = s1.cat.get_categories()
            if s2.dtype == pl.Categorical:
                cat2 = s2.cat.get_categories()
            else:
                cat2 = s2.unique().sort().cast(cat1.dtype)
            # populate cache, keep original categories first
            pl.concat([cat1, cat2]).cast(pl.Categorical)
            s1 = s1.cast(pl.Utf8).cast(pl.Categorical)
            s2 = s2.cast(pl.Utf8).cast(pl.Categorical)
    return s1, s2
| 397 | + ] |
| 398 | + }, |
361 | 399 | { |
362 | 400 | "cell_type": "code", |
363 | 401 | "execution_count": null, |
|
def vertical_concat(dfs: List[DataFrame]) -> DataFrame:
    """Stack dataframes vertically, aligning categorical columns first.

    Parameters
    ----------
    dfs : list of pandas or polars DataFrame
        Frames to stack. The type of the first frame selects the pandas or
        polars code path, and categorical columns are detected from it.

    Returns
    -------
    pandas or polars DataFrame
        The stacked frame. A single-element list returns that frame itself
        (not a copy). When several pandas frames are stacked the index is
        reset.

    Raises
    ------
    ValueError
        If `dfs` is empty.
    NotImplementedError
        If the first frame has categorical columns and more than two frames
        are provided.
    """
    if not dfs:
        raise ValueError("Can't concatenate empty list.")
    if len(dfs) == 1:
        # nothing to stack: return the frame, not the list that wraps it
        return dfs[0]
    if isinstance(dfs[0], pd.DataFrame):
        cat_cols = [c for c, dtype in dfs[0].dtypes.items() if isinstance(dtype, pd.CategoricalDtype)]
        if cat_cols:
            if len(dfs) > 2:
                raise NotImplementedError('Categorical replacement for more than two dataframes')
            # exactly two frames here: empty, single and >2 were handled above
            df1, df2 = dfs
            # shallow copies so the caller's frames keep their original columns
            df1 = df1.copy(deep=False)
            df2 = df2.copy(deep=False)
            for col in cat_cols:
                s1, s2 = match_if_categorical(df1[col], df2[col])
                df1[col] = s1
                df2[col] = s2
            dfs = [df1, df2]
        out = pd.concat(dfs).reset_index(drop=True)
    else:
        all_cols = dfs[0].columns
        cat_cols = [all_cols[i] for i, dtype in enumerate(dfs[0].dtypes) if dtype == pl.Categorical]
        if cat_cols:
            if len(dfs) > 2:
                raise NotImplementedError('Categorical replacement for more than two dataframes')
            # exactly two frames here: empty, single and >2 were handled above
            df1, df2 = dfs
            for col in cat_cols:
                s1, s2 = match_if_categorical(df1[col], df2[col])
                df1 = df1.with_columns(s1)
                df2 = df2.with_columns(s2)
            dfs = [df1, df2]
        out = pl.concat(dfs)
    return out
379 | 442 | ] |
380 | 443 | }, |
| 444 | + { |
| 445 | + "cell_type": "code", |
| 446 | + "execution_count": null, |
| 447 | + "id": "a21c0461-3964-4c82-a406-9fb7ea624f23", |
| 448 | + "metadata": {}, |
| 449 | + "outputs": [], |
| 450 | + "source": [ |
| 451 | + "df1 = pd.DataFrame({'x': ['a', 'b', 'c']}, dtype='category')\n", |
| 452 | + "df2 = pd.DataFrame({'x': ['f', 'b', 'a']}, dtype='category')\n", |
| 453 | + "pd.testing.assert_series_equal(\n", |
| 454 | + " vertical_concat([df1,df2])['x'],\n", |
| 455 | + " pd.Series(['a', 'b', 'c', 'f', 'b', 'a'], name='x', dtype=pd.CategoricalDtype(categories=['a', 'b', 'c', 'f']))\n", |
| 456 | + ")" |
| 457 | + ] |
| 458 | + }, |
| 459 | + { |
| 460 | + "cell_type": "code", |
| 461 | + "execution_count": null, |
| 462 | + "id": "986ab374-90fc-4ba8-b442-797abc63d2de", |
| 463 | + "metadata": {}, |
| 464 | + "outputs": [], |
| 465 | + "source": [ |
| 466 | + "#| polars\n", |
| 467 | + "df1 = pl.DataFrame({'x': ['a', 'b', 'c']}, schema={'x': pl.Categorical})\n", |
| 468 | + "df2 = pl.DataFrame({'x': ['f', 'b', 'a']}, schema={'x': pl.Categorical})\n", |
| 469 | + "out = vertical_concat([df1,df2])['x']\n", |
| 470 | + "assert out.series_equal(pl.Series('x', ['a', 'b', 'c', 'f', 'b', 'a']))\n", |
| 471 | + "assert out.to_physical().series_equal(pl.Series('x', [0, 1, 2, 3, 1, 0]))\n", |
| 472 | + "assert out.cat.get_categories().series_equal(\n", |
| 473 | + " pl.Series('x', ['a', 'b', 'c', 'f'])\n", |
| 474 | + ")" |
| 475 | + ] |
| 476 | + }, |
381 | 477 | { |
382 | 478 | "cell_type": "code", |
383 | 479 | "execution_count": null, |
|
452 | 548 | "source": [ |
453 | 549 | "#| export\n", |
454 | 550 | "def join(\n", |
455 | | - " df1: DataFrame,\n", |
456 | | - " df2: DataFrame,\n", |
| 551 | + " df1: Union[DataFrame, Series],\n", |
| 552 | + " df2: Union[DataFrame, Series],\n", |
457 | 553 | " on: Union[str, List[str]],\n", |
458 | 554 | " how: str = 'inner'\n", |
459 | 555 | ") -> DataFrame:\n", |
| 556 | + " if isinstance(df1, (pd.Series, pl_Series)):\n", |
| 557 | + " df1 = df1.to_frame()\n", |
| 558 | + " if isinstance(df2, (pd.Series, pl_Series)):\n", |
| 559 | + " df2 = df2.to_frame()\n", |
460 | 560 | " if isinstance(df1, pd.DataFrame):\n", |
461 | 561 | " out = df1.merge(df2, on=on, how=how)\n", |
462 | 562 | " else:\n", |
|
502 | 602 | "outputs": [], |
503 | 603 | "source": [ |
504 | 604 | "#| export\n", |
def sort(df: DataFrame, by: Optional[Union[str, List[str]]] = None) -> DataFrame:
    """Sort a dataframe, series or index in ascending order.

    Parameters
    ----------
    df : pandas DataFrame, Series or Index, or polars DataFrame or Series
        Object to sort.
    by : str or list of str, optional
        Column name(s) to sort by. Only used for dataframes.

    Returns
    -------
    Same kind of object as `df`
        The sorted values. pandas DataFrames and Series get a fresh
        0..n-1 index; a pandas Index is returned sorted as is.
    """
    if isinstance(df, pd.DataFrame):
        return df.sort_values(by).reset_index(drop=True)
    if isinstance(df, pd.Series):
        return df.sort_values().reset_index(drop=True)
    if isinstance(df, pd.Index):
        # an Index has no index of its own to reset
        return df.sort_values()
    if isinstance(df, pl_DataFrame):
        return df.sort(by)
    # anything else is sorted without a key
    return df.sort()
511 | 617 | ] |
512 | 618 | }, |
| 619 | + { |
| 620 | + "cell_type": "code", |
| 621 | + "execution_count": null, |
| 622 | + "id": "c14e0b1c-3770-4d8d-a8d0-63ed2bdf147c", |
| 623 | + "metadata": {}, |
| 624 | + "outputs": [], |
| 625 | + "source": [ |
| 626 | + "pd.testing.assert_frame_equal(\n", |
| 627 | + " sort(pd.DataFrame({'x': [3, 1, 2]}), 'x'),\n", |
| 628 | + " pd.DataFrame({'x': [1, 2, 3]})\n", |
| 629 | + ")\n", |
| 630 | + "pd.testing.assert_frame_equal(\n", |
| 631 | + " sort(pd.DataFrame({'x': [3, 1, 2]}), ['x']),\n", |
| 632 | + " pd.DataFrame({'x': [1, 2, 3]})\n", |
| 633 | + ")\n", |
| 634 | + "pd.testing.assert_series_equal(\n", |
| 635 | + " sort(pd.Series([3, 1, 2])),\n", |
| 636 | + " pd.Series([1, 2, 3])\n", |
| 637 | + ")\n", |
| 638 | + "pd.testing.assert_index_equal(\n", |
| 639 | + " sort(pd.Index([3, 1, 2])),\n", |
| 640 | + " pd.Index([1, 2, 3])\n", |
| 641 | + ")" |
| 642 | + ] |
| 643 | + }, |
| 644 | + { |
| 645 | + "cell_type": "code", |
| 646 | + "execution_count": null, |
| 647 | + "id": "43e1c151-1f81-442d-9f32-d88ca85a5e73", |
| 648 | + "metadata": {}, |
| 649 | + "outputs": [], |
| 650 | + "source": [ |
| 651 | + "#| polars\n", |
| 652 | + "# TODO: replace with pl.testing.assert_frame_equal when it's released\n", |
| 653 | + "pd.testing.assert_frame_equal(\n", |
| 654 | + " sort(pl.DataFrame({'x': [3, 1, 2]}), 'x').to_pandas(),\n", |
| 655 | + " pd.DataFrame({'x': [1, 2, 3]}),\n", |
| 656 | + ")\n", |
| 657 | + "pd.testing.assert_frame_equal(\n", |
| 658 | + " sort(pl.DataFrame({'x': [3, 1, 2]}), ['x']).to_pandas(),\n", |
| 659 | + " pd.DataFrame({'x': [1, 2, 3]}),\n", |
| 660 | + ")\n", |
| 661 | + "pd.testing.assert_series_equal(\n", |
| 662 | + " sort(pl.Series('x', [3, 1, 2])).to_pandas(),\n", |
| 663 | + " pd.Series([1, 2, 3], name='x')\n", |
| 664 | + ")" |
| 665 | + ] |
| 666 | + }, |
513 | 667 | { |
514 | 668 | "cell_type": "code", |
515 | 669 | "execution_count": null, |
|
557 | 711 | " return out" |
558 | 712 | ] |
559 | 713 | }, |
| 714 | + { |
| 715 | + "cell_type": "code", |
| 716 | + "execution_count": null, |
| 717 | + "id": "a2e3ff2c-9e70-46b3-9bf0-bbfcd339d9ba", |
| 718 | + "metadata": {}, |
| 719 | + "outputs": [], |
| 720 | + "source": [ |
| 721 | + "#| export\n", |
def group_by_agg(df: DataFrame, by, aggs, maintain_order=False) -> DataFrame:
    """Group `df` by `by` and aggregate columns.

    Parameters
    ----------
    df : pandas or polars DataFrame
        Frame to aggregate.
    by
        Grouping key(s), forwarded to `group_by`.
    aggs : dict
        Maps a column name to the name of the aggregation to apply to it
        (e.g. ``{'y': 'sum'}``). For pandas it is passed to `.agg` as is;
        otherwise each entry becomes ``pl.col(column).<aggregation>()``.
    maintain_order : bool (default=False)
        Forwarded to `group_by`.

    Returns
    -------
    pandas or polars DataFrame
        Aggregated frame. For pandas, `reset_index` is applied to the result.
    """
    grouped = group_by(df, by, maintain_order)
    if isinstance(df, pd.DataFrame):
        return grouped.agg(aggs).reset_index()
    exprs = []
    for column, method in aggs.items():
        # look the aggregation up by name on the column expression
        exprs.append(getattr(pl.col(column), method)())
    return grouped.agg(*exprs)
| 728 | + ] |
| 729 | + }, |
| 730 | + { |
| 731 | + "cell_type": "code", |
| 732 | + "execution_count": null, |
| 733 | + "id": "d9f92cd4-d3f7-4de1-b438-c3c5891c3343", |
| 734 | + "metadata": {}, |
| 735 | + "outputs": [], |
| 736 | + "source": [ |
| 737 | + "pd.testing.assert_frame_equal(\n", |
| 738 | + " group_by_agg(pd.DataFrame({'x': [1, 1, 2], 'y': [1, 1, 1]}), 'x', {'y': 'sum'}),\n", |
| 739 | + " pd.DataFrame({'x': [1, 2], 'y': [2, 1]})\n", |
| 740 | + ")" |
| 741 | + ] |
| 742 | + }, |
| 743 | + { |
| 744 | + "cell_type": "code", |
| 745 | + "execution_count": null, |
| 746 | + "id": "329cfc66-a218-498e-b674-96491f47a3e1", |
| 747 | + "metadata": {}, |
| 748 | + "outputs": [], |
| 749 | + "source": [ |
| 750 | + "#| polars\n", |
| 751 | + "pd.testing.assert_frame_equal(\n", |
| 752 | + " group_by_agg(pl.DataFrame({'x': [1, 1, 2], 'y': [1, 1, 1]}), 'x', {'y': 'sum'}, maintain_order=True).to_pandas(),\n", |
| 753 | + " pd.DataFrame({'x': [1, 2], 'y': [2, 1]})\n", |
| 754 | + ")" |
| 755 | + ] |
| 756 | + }, |
560 | 757 | { |
561 | 758 | "cell_type": "code", |
562 | 759 | "execution_count": null, |
|
594 | 791 | "np.testing.assert_equal(is_in(pl.Series([1, 2, 3]), [1]), np.array([True, False, False]))" |
595 | 792 | ] |
596 | 793 | }, |
| 794 | + { |
| 795 | + "cell_type": "code", |
| 796 | + "execution_count": null, |
| 797 | + "id": "9717022e-2c6f-47dc-8b19-da069341b094", |
| 798 | + "metadata": {}, |
| 799 | + "outputs": [], |
| 800 | + "source": [ |
| 801 | + "#| export\n", |
def between(s: Series, lower: Series, upper: Series) -> Series:
    """Flag the entries of `s` that fall between `lower` and `upper`.

    Parameters
    ----------
    s : pandas or polars Series
        Values to test.
    lower : pandas or polars Series
        Lower bounds.
    upper : pandas or polars Series
        Upper bounds.

    Returns
    -------
    pandas or polars Series
        Result of `s.between(lower, upper)` for pandas, or
        `s.is_between(lower, upper)` otherwise.
    """
    if not isinstance(s, pd.Series):
        return s.is_between(lower, upper)
    return s.between(lower, upper)
| 808 | + ] |
| 809 | + }, |
| 810 | + { |
| 811 | + "cell_type": "code", |
| 812 | + "execution_count": null, |
| 813 | + "id": "dca138b4-e771-4b8e-aa54-35dc37802d78", |
| 814 | + "metadata": {}, |
| 815 | + "outputs": [], |
| 816 | + "source": [ |
| 817 | + "np.testing.assert_equal(\n", |
| 818 | + " between(pd.Series([1, 2, 3]), pd.Series([0, 1, 4]), pd.Series([4, 1, 2])),\n", |
| 819 | + " np.array([True, False, False]),\n", |
| 820 | + ")" |
| 821 | + ] |
| 822 | + }, |
| 823 | + { |
| 824 | + "cell_type": "code", |
| 825 | + "execution_count": null, |
| 826 | + "id": "c6c773bf-fe23-4428-84f6-c5afaefdad06", |
| 827 | + "metadata": {}, |
| 828 | + "outputs": [], |
| 829 | + "source": [ |
| 830 | + "#| polars\n", |
| 831 | + "np.testing.assert_equal(\n", |
| 832 | + " between(pl.Series([1, 2, 3]), pl.Series([0, 1, 4]), pl.Series([4, 1, 2])),\n", |
| 833 | + " np.array([True, False, False]),\n", |
| 834 | + ")" |
| 835 | + ] |
| 836 | + }, |
| 837 | + { |
| 838 | + "cell_type": "code", |
| 839 | + "execution_count": null, |
| 840 | + "id": "667302bf-3b54-4298-8fcc-82cd6b12fb73", |
| 841 | + "metadata": {}, |
| 842 | + "outputs": [], |
| 843 | + "source": [ |
| 844 | + "#| export\n", |
def fill_null(df: DataFrame, mapping: Dict[str, Any]) -> DataFrame:
    """Fill missing values using a per-column replacement.

    Parameters
    ----------
    df : pandas or polars DataFrame
        Frame with missing values.
    mapping : dict
        Maps a column name to the value that replaces its missing entries.

    Returns
    -------
    pandas or polars DataFrame
        Result of `df.fillna(mapping)` for pandas; otherwise the frame with
        `fill_null` applied to each column in `mapping`.
    """
    if not isinstance(df, pd.DataFrame):
        fill_exprs = [pl.col(name).fill_null(value) for name, value in mapping.items()]
        return df.with_columns(*fill_exprs)
    return df.fillna(mapping)
| 851 | + ] |
| 852 | + }, |
| 853 | + { |
| 854 | + "cell_type": "code", |
| 855 | + "execution_count": null, |
| 856 | + "id": "74993c58-0886-4290-ab90-8065651886c5", |
| 857 | + "metadata": {}, |
| 858 | + "outputs": [], |
| 859 | + "source": [ |
| 860 | + "pd.testing.assert_frame_equal(\n", |
| 861 | + " fill_null(pd.DataFrame({'x': [1, np.nan], 'y': [np.nan, 2]}), {'x': 2, 'y': 1}),\n", |
| 862 | + " pd.DataFrame({'x': [1, 2], 'y': [1, 2]}, dtype='float64')\n", |
| 863 | + ")" |
| 864 | + ] |
| 865 | + }, |
| 866 | + { |
| 867 | + "cell_type": "code", |
| 868 | + "execution_count": null, |
| 869 | + "id": "ec1d835a-f1dc-4c1a-be2d-e7dd5b9895ad", |
| 870 | + "metadata": {}, |
| 871 | + "outputs": [], |
| 872 | + "source": [ |
| 873 | + "#| polars\n", |
| 874 | + "# TODO: replace with pl.testing.assert_frame_equal when it's released\n", |
| 875 | + "pd.testing.assert_frame_equal(\n", |
| 876 | + " fill_null(pl.DataFrame({'x': [1, None], 'y': [None, 2]}), {'x': 2, 'y': 1}).to_pandas(),\n", |
| 877 | + " pd.DataFrame({'x': [1, 2], 'y': [1, 2]})\n", |
| 878 | + ")" |
| 879 | + ] |
| 880 | + }, |
597 | 881 | { |
598 | 882 | "cell_type": "code", |
599 | 883 | "execution_count": null, |
|
0 commit comments