Skip to content

Commit 98a70df

Browse files
committed
Use _mgr.apply in rank; 2 tests still failing
1 parent 188b2da commit 98a70df

File tree

4 files changed

+48
-18
lines changed

4 files changed

+48
-18
lines changed

pandas/core/algorithms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,6 @@ def rank(
10881088
)
10891089
else:
10901090
raise TypeError("Array with ndim > 2 are not supported.")
1091-
10921091
return ranks
10931092

10941093

pandas/core/arrays/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,6 +2408,7 @@ def _rank(
24082408
"""
24092409
See Series.rank.__doc__.
24102410
"""
2411+
24112412
if axis != 0:
24122413
raise NotImplementedError
24132414

pandas/core/generic.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9275,34 +9275,25 @@ def rank(
92759275
msg = "na_option must be one of 'keep', 'top', or 'bottom'"
92769276
raise ValueError(msg)
92779277

9278-
def ranker(data):
9279-
if data.ndim == 2:
9280-
# i.e. DataFrame, we cast to ndarray
9281-
values = data.values
9282-
else:
9283-
# i.e. Series, can dispatch to EA
9284-
values = data._values
9285-
9286-
if isinstance(values, ExtensionArray):
9287-
ranks = values._rank(
9288-
axis=axis_int,
9278+
def ranker(blk_values):
9279+
if isinstance(blk_values, ExtensionArray) and blk_values.ndim == 1:
9280+
ranks = blk_values._rank(
9281+
axis=0,
92899282
method=method,
92909283
ascending=ascending,
92919284
na_option=na_option,
92929285
pct=pct,
92939286
)
92949287
else:
92959288
ranks = algos.rank(
9296-
values,
9297-
axis=axis_int,
9289+
blk_values,
9290+
axis=1 - axis_int,
92989291
method=method,
92999292
ascending=ascending,
93009293
na_option=na_option,
93019294
pct=pct,
93029295
)
9303-
9304-
ranks_obj = self._constructor(ranks, **data._construct_axes_dict())
9305-
return ranks_obj.__finalize__(self, method="rank")
9296+
return ranks
93069297

93079298
if numeric_only:
93089299
if self.ndim == 1 and not is_numeric_dtype(self.dtype):
@@ -9315,7 +9306,10 @@ def ranker(data):
93159306
else:
93169307
data = self
93179308

9318-
return ranker(data)
9309+
result = data._mgr.apply(ranker)
9310+
return self._constructor_from_mgr(result, axes=result.axes).__finalize__(
9311+
self, method="rank"
9312+
)
93199313

93209314
@doc(_shared_docs["compare"], klass=_shared_doc_kwargs["klass"])
93219315
def compare(

pandas/tests/frame/methods/test_rank.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def test_rank_inf_and_nan(self, contents, dtype, frame_or_series):
405405
# Shuffle the testing array and expected results in the same way
406406
random_order = np.random.default_rng(2).permutation(len(values))
407407
obj = frame_or_series(values[random_order])
408+
print("TYPE", type(obj))
408409
expected = frame_or_series(exp_order[random_order], dtype="float64")
409410
result = obj.rank()
410411
tm.assert_equal(result, expected)
@@ -498,3 +499,38 @@ def test_rank_string_dtype(self, string_dtype_no_object):
498499
exp_dtype = "float64"
499500
expected = Series([1, 2, None, 3], dtype=exp_dtype)
500501
tm.assert_series_equal(result, expected)
502+
503+
@pytest.mark.parametrize(
504+
"method,og_dtype,expected_dtype",
505+
[
506+
("average", "UInt32", "Float64"),
507+
("average", "Float32", "Float64"),
508+
("average", "int32[pyarrow]", "double[pyarrow]"),
509+
("min", "Int32", "Float64"),
510+
("min", "Float32", "Float64"),
511+
("min", "int32[pyarrow]", "double[pyarrow]"),
512+
],
513+
)
514+
def test_rank_extension_array_dtype(self, method, og_dtype, expected_dtype):
515+
# GH#52829
516+
result = DataFrame([4, 89, 33], dtype=og_dtype).rank()
517+
if method == "average":
518+
expected = DataFrame([1.0, 3.0, 2.0], dtype=expected_dtype)
519+
else:
520+
expected = DataFrame([1, 3, 2], dtype=expected_dtype)
521+
tm.assert_frame_equal(result, expected)
522+
523+
def test_rank_mixed_extension_array_dtype(self):
524+
result = DataFrame(
525+
{
526+
"base": Series([4, 5, 6]),
527+
"extension": Series([7, 8, 9], dtype="int32[pyarrow]"),
528+
}
529+
).rank(method="min")
530+
expected = DataFrame(
531+
{
532+
"base": Series([1.0, 2.0, 3.0], dtype="float64"),
533+
"extension": Series([1, 2, 3], dtype="uint64[pyarrow]"),
534+
}
535+
)
536+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)