Skip to content

Commit 5432f2a

Browse files
BUG/API (string dtype): return float dtype for series[str].rank()
1 parent 47b56ea commit 5432f2a

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

pandas/core/arrays/string_arrow.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
3232
from pandas.core.arrays.arrow import ArrowExtensionArray
3333
from pandas.core.arrays.boolean import BooleanDtype
34+
from pandas.core.arrays.floating import Float64Dtype
3435
from pandas.core.arrays.integer import Int64Dtype
3536
from pandas.core.arrays.numeric import NumericDtype
3637
from pandas.core.arrays.string_ import (
@@ -444,6 +445,16 @@ def _convert_int_result(self, result):
444445

445446
return Int64Dtype().__from_arrow__(result)
446447

448+
def _convert_float_result(self, result):
    """
    Wrap a floating-point result according to this array's NA semantics.

    Parameters
    ----------
    result : pyarrow array-like
        Object exposing ``to_numpy``; either a ``pa.Array`` or some other
        pyarrow container (presumably a ``pa.ChunkedArray`` -- confirm
        against callers).

    Returns
    -------
    numpy.ndarray or masked float array
        A ``float64`` ndarray when ``self.dtype.na_value`` is ``np.nan``;
        otherwise whatever ``Float64Dtype().__from_arrow__`` builds from
        ``result``.
    """
    # Dtypes whose missing value is not NaN get the Float64 extension type.
    if self.dtype.na_value is not np.nan:
        return Float64Dtype().__from_arrow__(result)

    # Only the ``pa.Array`` call is given ``zero_copy_only=False``; any
    # other container is converted with its default ``to_numpy()``.
    values = (
        result.to_numpy(zero_copy_only=False)
        if isinstance(result, pa.Array)
        else result.to_numpy()
    )
    # copy=False: do not force a copy when the data is already float64.
    return values.astype("float64", copy=False)
457+
447458
def _reduce(
448459
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
449460
):
@@ -477,7 +488,7 @@ def _rank(
477488
"""
478489
See Series.rank.__doc__.
479490
"""
480-
return self._convert_int_result(
491+
return self._convert_float_result(
481492
self._rank_calc(
482493
axis=axis,
483494
method=method,

pandas/tests/series/methods/test_rank.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def ser():
3333
["max", np.array([2, 6, 7, 4, np.nan, 4, 2, 8, np.nan, 6])],
3434
["first", np.array([1, 5, 7, 3, np.nan, 4, 2, 8, np.nan, 6])],
3535
["dense", np.array([1, 3, 4, 2, np.nan, 2, 1, 5, np.nan, 3])],
36-
]
36+
],
37+
ids=lambda x: x[0],
3738
)
3839
def results(request):
3940
return request.param
@@ -48,12 +49,29 @@ def results(request):
4849
"Int64",
4950
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
5051
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
52+
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
53+
"string[python]",
54+
"str",
5155
]
5256
)
5357
def dtype(request):
5458
return request.param
5559

5660

61+
def expected_dtype(dtype, method, pct=False):
    """
    Return the dtype name the rank tests expect for a ranked Series.

    Parameters
    ----------
    dtype : str
        Name of the dtype the input Series was cast to.
    method : str
        Tie-breaking method passed to ``rank``.
    pct : bool, default False
        Whether ``rank`` was called with ``pct=True``.

    Returns
    -------
    str
        * ``"Float64"`` for ``"string[pyarrow]"``.
        * For ``"float64[pyarrow]"`` and ``"int64[pyarrow]"``:
          ``"double[pyarrow]"`` when ``method == "average"`` or ``pct`` is
          truthy, otherwise ``"uint64[pyarrow]"``.
        * ``"float64"`` for every other dtype.
    """
    if dtype == "string[pyarrow]":
        return "Float64"

    if dtype in ("float64[pyarrow]", "int64[pyarrow]"):
        if method == "average" or pct:
            return "double[pyarrow]"
        return "uint64[pyarrow]"

    return "float64"
73+
74+
5775
class TestSeriesRank:
5876
def test_rank(self, datetime_series):
5977
sp_stats = pytest.importorskip("scipy.stats")
@@ -251,12 +269,14 @@ def test_rank_signature(self):
251269
with pytest.raises(ValueError, match=msg):
252270
s.rank("average")
253271

254-
@pytest.mark.parametrize("dtype", [None, object])
255-
def test_rank_tie_methods(self, ser, results, dtype):
272+
def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
256273
method, exp = results
274+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
275+
pytest.skip("int64/str does not support NaN")
276+
257277
ser = ser if dtype is None else ser.astype(dtype)
258278
result = ser.rank(method=method)
259-
tm.assert_series_equal(result, Series(exp))
279+
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
260280

261281
@pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
262282
@pytest.mark.parametrize(
@@ -357,25 +377,35 @@ def test_rank_methods_series(self, rank_method, op, value):
357377
],
358378
)
359379
def test_rank_dense_method(self, dtype, ser, exp):
380+
if ser[0] < 0 and dtype.startswith("str"):
381+
exp = exp[::-1]
360382
s = Series(ser).astype(dtype)
361383
result = s.rank(method="dense")
362-
expected = Series(exp).astype(result.dtype)
384+
expected = Series(exp).astype(expected_dtype(dtype, "dense"))
363385
tm.assert_series_equal(result, expected)
364386

365-
def test_rank_descending(self, ser, results, dtype):
387+
def test_rank_descending(self, ser, results, dtype, using_infer_string):
366388
method, _ = results
367-
if "i" in dtype:
389+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
368390
s = ser.dropna()
369391
else:
370392
s = ser.astype(dtype)
371393

372394
res = s.rank(ascending=False)
373-
expected = (s.max() - s).rank()
374-
tm.assert_series_equal(res, expected)
395+
if dtype.startswith("str"):
396+
expected = (s.astype("float64").max() - s.astype("float64")).rank()
397+
else:
398+
expected = (s.max() - s).rank()
399+
tm.assert_series_equal(res, expected.astype(expected_dtype(dtype, "average")))
375400

376-
expected = (s.max() - s).rank(method=method)
401+
if dtype.startswith("str"):
402+
expected = (s.astype("float64").max() - s.astype("float64")).rank(
403+
method=method
404+
)
405+
else:
406+
expected = (s.max() - s).rank(method=method)
377407
res2 = s.rank(method=method, ascending=False)
378-
tm.assert_series_equal(res2, expected)
408+
tm.assert_series_equal(res2, expected.astype(expected_dtype(dtype, method)))
379409

380410
def test_rank_int(self, ser, results):
381411
method, exp = results
@@ -432,9 +462,11 @@ def test_rank_ea_small_values(self):
432462
],
433463
)
434464
def test_rank_dense_pct(dtype, ser, exp):
465+
if ser[0] < 0 and dtype.startswith("str"):
466+
exp = exp[::-1]
435467
s = Series(ser).astype(dtype)
436468
result = s.rank(method="dense", pct=True)
437-
expected = Series(exp).astype(result.dtype)
469+
expected = Series(exp).astype(expected_dtype(dtype, "dense", pct=True))
438470
tm.assert_series_equal(result, expected)
439471

440472

@@ -453,9 +485,11 @@ def test_rank_dense_pct(dtype, ser, exp):
453485
],
454486
)
455487
def test_rank_min_pct(dtype, ser, exp):
488+
if ser[0] < 0 and dtype.startswith("str"):
489+
exp = exp[::-1]
456490
s = Series(ser).astype(dtype)
457491
result = s.rank(method="min", pct=True)
458-
expected = Series(exp).astype(result.dtype)
492+
expected = Series(exp).astype(expected_dtype(dtype, "min", pct=True))
459493
tm.assert_series_equal(result, expected)
460494

461495

@@ -474,9 +508,11 @@ def test_rank_min_pct(dtype, ser, exp):
474508
],
475509
)
476510
def test_rank_max_pct(dtype, ser, exp):
511+
if ser[0] < 0 and dtype.startswith("str"):
512+
exp = exp[::-1]
477513
s = Series(ser).astype(dtype)
478514
result = s.rank(method="max", pct=True)
479-
expected = Series(exp).astype(result.dtype)
515+
expected = Series(exp).astype(expected_dtype(dtype, "max", pct=True))
480516
tm.assert_series_equal(result, expected)
481517

482518

@@ -495,9 +531,11 @@ def test_rank_max_pct(dtype, ser, exp):
495531
],
496532
)
497533
def test_rank_average_pct(dtype, ser, exp):
534+
if ser[0] < 0 and dtype.startswith("str"):
535+
exp = exp[::-1]
498536
s = Series(ser).astype(dtype)
499537
result = s.rank(method="average", pct=True)
500-
expected = Series(exp).astype(result.dtype)
538+
expected = Series(exp).astype(expected_dtype(dtype, "average", pct=True))
501539
tm.assert_series_equal(result, expected)
502540

503541

@@ -516,9 +554,11 @@ def test_rank_average_pct(dtype, ser, exp):
516554
],
517555
)
518556
def test_rank_first_pct(dtype, ser, exp):
557+
if ser[0] < 0 and dtype.startswith("str"):
558+
exp = exp[::-1]
519559
s = Series(ser).astype(dtype)
520560
result = s.rank(method="first", pct=True)
521-
expected = Series(exp).astype(result.dtype)
561+
expected = Series(exp).astype(expected_dtype(dtype, "first", pct=True))
522562
tm.assert_series_equal(result, expected)
523563

524564

0 commit comments

Comments
 (0)