|
16 | 16 | DataFrame, |
17 | 17 | Index, |
18 | 18 | Series, |
| 19 | + to_datetime, |
| 20 | + to_timedelta, |
19 | 21 | ) |
20 | 22 | import pandas._testing as tm |
21 | 23 |
|
@@ -511,37 +513,84 @@ def test_rank_string_dtype(self, string_dtype_no_object): |
511 | 513 | "double[pyarrow]", |
512 | 514 | marks=td.skip_if_no("pyarrow"), |
513 | 515 | ), |
514 | | - ("min", "Int32", "Float64"), |
515 | | - ("min", "Float32", "Float64"), |
| 516 | + ("min", "Int32", "UInt64"), |
| 517 | + ("min", "Float32", "UInt64"), |
516 | 518 | pytest.param( |
517 | 519 | "min", |
518 | 520 | "int32[pyarrow]", |
519 | | - "double[pyarrow]", |
| 521 | + "uint64[pyarrow]", |
520 | 522 | marks=td.skip_if_no("pyarrow"), |
521 | 523 | ), |
522 | 524 | ], |
523 | 525 | ) |
524 | 526 | def test_rank_extension_array_dtype(self, method, og_dtype, expected_dtype): |
525 | 527 | # GH#52829 |
526 | | - result = DataFrame([4, 89, 33], dtype=og_dtype).rank() |
| 528 | + result = DataFrame([4, 89, 33], dtype=og_dtype).rank(method=method) |
527 | 529 | if method == "average": |
528 | 530 | expected = DataFrame([1.0, 3.0, 2.0], dtype=expected_dtype) |
529 | 531 | else: |
530 | 532 | expected = DataFrame([1, 3, 2], dtype=expected_dtype) |
531 | 533 | tm.assert_frame_equal(result, expected) |
532 | 534 |
|
533 | 535 | def test_rank_mixed_extension_array_dtype(self): |
| 536 | + # GH#52829 |
534 | 537 | pytest.importorskip("pyarrow") |
535 | 538 | result = DataFrame( |
536 | 539 | { |
537 | 540 | "base": Series([4, 5, 6]), |
538 | | - "extension": Series([7, 8, 9], dtype="int32[pyarrow]"), |
| 541 | + "pyarrow": Series([7, 8, 9], dtype="int32[pyarrow]"), |
539 | 542 | } |
540 | 543 | ).rank(method="min") |
541 | 544 | expected = DataFrame( |
542 | 545 | { |
543 | 546 | "base": Series([1.0, 2.0, 3.0], dtype="float64"), |
544 | | - "extension": Series([1, 2, 3], dtype="uint64[pyarrow]"), |
| 547 | + "pyarrow": Series([1, 2, 3], dtype="uint64[pyarrow]"), |
545 | 548 | } |
546 | 549 | ) |
547 | 550 | tm.assert_frame_equal(result, expected) |
| 551 | + |
| 552 | + def test_2d_extension_array_datetime(self): |
| 553 | + # GH#52829 |
| 554 | + df = DataFrame( |
| 555 | + { |
| 556 | + "year": to_datetime(["2012-1-1", "2013-1-1", "2014-1-1"]), |
| 557 | + "week": to_datetime(["2012-1-2", "2012-1-9", "2012-1-16"]), |
| 558 | + "day": to_datetime(["2012-1-3", "2012-1-4", "2012-1-5"]), |
| 559 | + } |
| 560 | + ) |
| 561 | + axis0_expected = DataFrame( |
| 562 | + {"year": [1.0, 2.0, 3.0], "week": [1.0, 2.0, 3.0], "day": [1.0, 2.0, 3.0]} |
| 563 | + ) |
| 564 | + axis1_expected = DataFrame( |
| 565 | + {"year": [1.0, 3.0, 3.0], "week": [2.0, 2.0, 2.0], "day": [3.0, 1.0, 1.0]} |
| 566 | + ) |
| 567 | + tm.assert_frame_equal(df.rank(), axis0_expected) |
| 568 | + tm.assert_frame_equal(df.rank(1), axis1_expected) |
| 569 | + |
| 570 | + def test_2d_extension_array_timedelta(self): |
| 571 | + # GH#52829 |
| 572 | + df = DataFrame( |
| 573 | + { |
| 574 | + "day": to_timedelta(["0 days", "1 day", "2 days"]), |
| 575 | + "hourly": to_timedelta(["23 hours", "24 hours", "25 hours"]), |
| 576 | + "minute": to_timedelta( |
| 577 | + ["1439 minutes", "1440 minutes", "1441 minutes"] |
| 578 | + ), |
| 579 | + } |
| 580 | + ) |
| 581 | + axis0_expected = DataFrame( |
| 582 | + { |
| 583 | + "day": [1.0, 2.0, 3.0], |
| 584 | + "hourly": [1.0, 2.0, 3.0], |
| 585 | + "minute": [1.0, 2.0, 3.0], |
| 586 | + } |
| 587 | + ) |
| 588 | + axis1_expected = DataFrame( |
| 589 | + { |
| 590 | + "day": [1.0, 2.0, 3.0], |
| 591 | + "hourly": [2.0, 2.0, 2.0], |
| 592 | + "minute": [3.0, 2.0, 1.0], |
| 593 | + } |
| 594 | + ) |
| 595 | + tm.assert_frame_equal(df.rank(), axis0_expected) |
| 596 | + tm.assert_frame_equal(df.rank(1), axis1_expected) |
0 commit comments