Skip to content

Commit fce155a

Browse files
committed
2d extension arrays
1 parent 15922e8 commit fce155a

File tree

3 files changed: 69 additions and 29 deletions (+69 −29 lines changed)

pandas/core/arrays/base.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2415,9 +2415,6 @@ def _rank(
         See Series.rank.__doc__.
         """
 
-        if axis != 0:
-            raise NotImplementedError
-
         return rank(
             self,
             axis=axis,

pandas/core/generic.py

Lines changed: 14 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -9277,33 +9277,27 @@ def rank(
             raise ValueError(msg)
 
         def ranker(blk_values):
-            if isinstance(blk_values, ExtensionArray) and blk_values.ndim == 1:
+            if axis_int == 0:
+                blk_values = blk_values.T
+            if isinstance(blk_values, ExtensionArray):
                 ranks = blk_values._rank(
-                    axis=0,
+                    axis=axis_int,
                     method=method,
                     ascending=ascending,
                     na_option=na_option,
                     pct=pct,
                 )
             else:
-                if axis_int == 0:
-                    ranks = algos.rank(
-                        blk_values.T,
-                        axis=axis_int,
-                        method=method,
-                        ascending=ascending,
-                        na_option=na_option,
-                        pct=pct,
-                    ).T
-                else:
-                    ranks = algos.rank(
-                        blk_values,
-                        axis=axis_int,
-                        method=method,
-                        ascending=ascending,
-                        na_option=na_option,
-                        pct=pct,
-                    )
+                ranks = algos.rank(
+                    blk_values,
+                    axis=axis_int,
+                    method=method,
+                    ascending=ascending,
+                    na_option=na_option,
+                    pct=pct,
+                )
+            if axis_int == 0:
+                ranks = ranks.T
             return ranks
 
         if numeric_only:

pandas/tests/frame/methods/test_rank.py

Lines changed: 55 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
     DataFrame,
     Index,
     Series,
+    to_datetime,
+    to_timedelta,
 )
 import pandas._testing as tm

@@ -511,37 +513,84 @@ def test_rank_string_dtype(self, string_dtype_no_object):
                 "double[pyarrow]",
                 marks=td.skip_if_no("pyarrow"),
             ),
-            ("min", "Int32", "Float64"),
-            ("min", "Float32", "Float64"),
+            ("min", "Int32", "UInt64"),
+            ("min", "Float32", "UInt64"),
             pytest.param(
                 "min",
                 "int32[pyarrow]",
-                "double[pyarrow]",
+                "uint64[pyarrow]",
                 marks=td.skip_if_no("pyarrow"),
             ),
         ],
     )
     def test_rank_extension_array_dtype(self, method, og_dtype, expected_dtype):
         # GH#52829
-        result = DataFrame([4, 89, 33], dtype=og_dtype).rank()
+        result = DataFrame([4, 89, 33], dtype=og_dtype).rank(method=method)
         if method == "average":
             expected = DataFrame([1.0, 3.0, 2.0], dtype=expected_dtype)
         else:
             expected = DataFrame([1, 3, 2], dtype=expected_dtype)
         tm.assert_frame_equal(result, expected)
 
     def test_rank_mixed_extension_array_dtype(self):
+        # GH#52829
         pytest.importorskip("pyarrow")
         result = DataFrame(
             {
                 "base": Series([4, 5, 6]),
-                "extension": Series([7, 8, 9], dtype="int32[pyarrow]"),
+                "pyarrow": Series([7, 8, 9], dtype="int32[pyarrow]"),
             }
         ).rank(method="min")
         expected = DataFrame(
             {
                 "base": Series([1.0, 2.0, 3.0], dtype="float64"),
-                "extension": Series([1, 2, 3], dtype="uint64[pyarrow]"),
+                "pyarrow": Series([1, 2, 3], dtype="uint64[pyarrow]"),
             }
         )
         tm.assert_frame_equal(result, expected)
+
+    def test_2d_extension_array_datetime(self):
+        # GH#52829
+        df = DataFrame(
+            {
+                "year": to_datetime(["2012-1-1", "2013-1-1", "2014-1-1"]),
+                "week": to_datetime(["2012-1-2", "2012-1-9", "2012-1-16"]),
+                "day": to_datetime(["2012-1-3", "2012-1-4", "2012-1-5"]),
+            }
+        )
+        axis0_expected = DataFrame(
+            {"year": [1.0, 2.0, 3.0], "week": [1.0, 2.0, 3.0], "day": [1.0, 2.0, 3.0]}
+        )
+        axis1_expected = DataFrame(
+            {"year": [1.0, 3.0, 3.0], "week": [2.0, 2.0, 2.0], "day": [3.0, 1.0, 1.0]}
+        )
+        tm.assert_frame_equal(df.rank(), axis0_expected)
+        tm.assert_frame_equal(df.rank(1), axis1_expected)
+
+    def test_2d_extension_array_timedelta(self):
+        # GH#52829
+        df = DataFrame(
+            {
+                "day": to_timedelta(["0 days", "1 day", "2 days"]),
+                "hourly": to_timedelta(["23 hours", "24 hours", "25 hours"]),
+                "minute": to_timedelta(
+                    ["1439 minutes", "1440 minutes", "1441 minutes"]
+                ),
+            }
+        )
+        axis0_expected = DataFrame(
+            {
+                "day": [1.0, 2.0, 3.0],
+                "hourly": [1.0, 2.0, 3.0],
+                "minute": [1.0, 2.0, 3.0],
+            }
+        )
+        axis1_expected = DataFrame(
+            {
+                "day": [1.0, 2.0, 3.0],
+                "hourly": [2.0, 2.0, 2.0],
+                "minute": [3.0, 2.0, 1.0],
+            }
+        )
+        tm.assert_frame_equal(df.rank(), axis0_expected)
+        tm.assert_frame_equal(df.rank(1), axis1_expected)

0 commit comments

Comments
 (0)