Skip to content

Commit 6eefa20

Browse files
author
Michele Pau
committed
added support for ordered categoricals in kendall and spearman correlation
1 parent cfd0d3f commit 6eefa20

File tree

5 files changed

+117
-19
lines changed

5 files changed

+117
-19
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Other enhancements
5656
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
5757
- :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
5858
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
59+
- :meth:`Series.corr`, :meth:`DataFrame.corr` and :meth:`DataFrame.corrwith` now support ordered categorical data when ``method="kendall"`` or ``method="spearman"`` is used (:issue:`60306`)
5960
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)
6061
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
6162
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)

pandas/core/frame.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11034,6 +11034,10 @@ def corr(
1103411034
data = self._get_numeric_data() if numeric_only else self
1103511035
cols = data.columns
1103611036
idx = cols.copy()
11037+
11038+
if method in ("spearman", "kendall"):
11039+
data = data._convert_ordered_cat_to_code()
11040+
1103711041
mat = data.to_numpy(dtype=float, na_value=np.nan, copy=False)
1103811042

1103911043
if method == "pearson":
@@ -11321,6 +11325,8 @@ def corrwith(
1132111325
correl = num / dom
1132211326

1132311327
elif method in ["kendall", "spearman"] or callable(method):
11328+
left = left._convert_ordered_cat_to_code()
11329+
right = right._convert_ordered_cat_to_code()
1132411330

1132511331
def c(x):
1132611332
return nanops.nancorr(x[0], x[1], method=method)
@@ -11352,6 +11358,24 @@ def c(x):
1135211358

1135311359
return correl
1135411360

11361+
def _convert_ordered_cat_to_code(self) -> DataFrame:
11362+
"""
11363+
Converts all category columns to their codes wherever possible
11364+
(i.e. wherever they are ordered) otherwise leaves shape unchanged
11365+
"""
11366+
categ = self.select_dtypes("category")
11367+
if len(categ.columns) == 0:
11368+
return self
11369+
11370+
cols_convert = categ.loc[:, categ.agg(lambda x: x.cat.ordered)].columns
11371+
if len(cols_convert) > 0:
11372+
data = self.copy(deep=False)
11373+
data[cols_convert] = data[cols_convert].transform(
11374+
lambda x: x.cat.codes.replace(-1, np.nan)
11375+
)
11376+
11377+
return data
11378+
1135511379
# ----------------------------------------------------------------------
1135611380
# ndarray-like stats methods
1135711381

pandas/core/series.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,6 +2680,12 @@ def corr(
26802680
if len(this) == 0:
26812681
return np.nan
26822682

2683+
if method in ("spearman", "kendall"):
2684+
if this.dtype == "category" and this.cat.ordered:
2685+
this = this.cat.codes.replace(-1, np.nan)
2686+
if other.dtype == "category" and other.cat.ordered:
2687+
other = other.cat.codes.replace(-1, np.nan)
2688+
26832689
this_values = this.to_numpy(dtype=float, na_value=np.nan, copy=False)
26842690
other_values = other.to_numpy(dtype=float, na_value=np.nan, copy=False)
26852691

pandas/tests/frame/methods/test_cov_corr.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pandas as pd
99
from pandas import (
10+
Categorical,
1011
DataFrame,
1112
Index,
1213
Series,
@@ -16,6 +17,19 @@
1617
import pandas._testing as tm
1718

1819

20+
@pytest.fixture
21+
def categorical_frame():
22+
frame = DataFrame(
23+
{
24+
"A": Categorical(list("abcde") * 6, list("bacde"), ordered=True),
25+
"B": Categorical(list("123") * 10, list("321"), ordered=True),
26+
}
27+
)
28+
frame.loc[frame.index[:5], "A"] = np.nan
29+
frame.loc[frame.index[3:6], "B"] = np.nan
30+
return frame
31+
32+
1933
class TestDataFrameCov:
2034
def test_cov(self, float_frame, float_string_frame):
2135
# min_periods no NAs (corner case)
@@ -116,6 +130,13 @@ def test_corr_scipy_method(self, float_frame, method):
116130
expected = float_frame["A"].corr(float_frame["C"], method=method)
117131
tm.assert_almost_equal(correls["A"]["C"], expected)
118132

133+
@pytest.mark.parametrize("method", ["kendall", "spearman"])
134+
def test_corr_scipy_method_category(self, method, categorical_frame):
135+
pytest.importorskip("scipy")
136+
correls = categorical_frame.corr(method=method)
137+
expected = categorical_frame["A"].corr(categorical_frame["B"], method=method)
138+
tm.assert_almost_equal(correls["A"]["B"], expected)
139+
119140
# ---------------------------------------------------------------------
120141

121142
def test_corr_non_numeric(self, float_string_frame):
@@ -303,6 +324,14 @@ def test_corrwith(self, datetime_frame, dtype):
303324
dropped = a.corrwith(b, axis=1, drop=True)
304325
assert a.index[-1] not in dropped.index
305326

327+
@pytest.mark.parametrize("method", ["spearman", "kendall"])
328+
def test_corrwith_categorical(self, categorical_frame, method):
329+
other = categorical_frame["B"]
330+
result = categorical_frame.corrwith(other, method=method)
331+
expected = categorical_frame.agg(lambda x: x.corr(other, method=method))
332+
tm.assert_almost_equal(result["A"], expected["A"])
333+
tm.assert_almost_equal(result["B"], expected["B"])
334+
306335
def test_corrwith_non_timeseries_data(self):
307336
index = ["a", "b", "c", "d", "e"]
308337
columns = ["one", "two", "three", "four"]

pandas/tests/series/methods/test_cov_corr.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,43 @@
55

66
import pandas as pd
77
from pandas import (
8+
Categorical,
89
Series,
910
date_range,
1011
isna,
1112
)
1213
import pandas._testing as tm
1314

1415

16+
@pytest.fixture
17+
def A():
18+
return Series(
19+
np.concatenate([np.arange(5, dtype=np.float64)] * 2),
20+
index=date_range("2020-01-01", periods=10),
21+
name="ts",
22+
)
23+
24+
25+
@pytest.fixture
26+
def B():
27+
return Series(
28+
np.arange(10, dtype=np.float64),
29+
index=date_range("2020-01-01", periods=10),
30+
name="ts",
31+
)
32+
33+
34+
@pytest.fixture
35+
def C():
36+
s = Series(
37+
data=Categorical(list("12345") * 2, categories=list("54321"), ordered=True),
38+
index=date_range("2020-01-01", periods=10),
39+
name="categorical",
40+
)
41+
s["2020-01-03"] = np.nan
42+
return s
43+
44+
1545
class TestSeriesCov:
1646
def test_cov(self, datetime_series):
1747
# full overlap
@@ -56,7 +86,7 @@ def test_cov_ddof(self, test_ddof, dtype):
5686

5787

5888
class TestSeriesCorr:
59-
def test_corr(self, datetime_series, any_float_dtype):
89+
def test_corr(self, B, datetime_series, any_float_dtype):
6090
stats = pytest.importorskip("scipy.stats")
6191

6292
datetime_series = datetime_series.astype(any_float_dtype)
@@ -81,29 +111,14 @@ def test_corr(self, datetime_series, any_float_dtype):
81111
cp[:] = np.nan
82112
assert isna(cp.corr(cp))
83113

84-
A = Series(
85-
np.arange(10, dtype=np.float64),
86-
index=date_range("2020-01-01", periods=10),
87-
name="ts",
88-
)
89-
result = A.corr(A)
90-
expected, _ = stats.pearsonr(A, A)
114+
result = B.corr(B)
115+
expected, _ = stats.pearsonr(B, B)
91116
tm.assert_almost_equal(result, expected)
92117

93-
def test_corr_rank(self):
118+
def test_corr_rank(self, A, B):
94119
stats = pytest.importorskip("scipy.stats")
95120

96121
# kendall and spearman
97-
B = Series(
98-
np.arange(10, dtype=np.float64),
99-
index=date_range("2020-01-01", periods=10),
100-
name="ts",
101-
)
102-
A = Series(
103-
np.concatenate([np.arange(5, dtype=np.float64)] * 2),
104-
index=date_range("2020-01-01", periods=10),
105-
name="ts",
106-
)
107122
result = A.corr(B, method="kendall")
108123
expected = stats.kendalltau(A, B)[0]
109124
tm.assert_almost_equal(result, expected)
@@ -146,6 +161,29 @@ def test_corr_rank(self):
146161
tm.assert_almost_equal(A.corr(B, method="kendall"), kexp)
147162
tm.assert_almost_equal(A.corr(B, method="spearman"), sexp)
148163

164+
def test_corr_category(self, A, C):
    """
    Check Series.corr between a float Series and an ordered categorical.

    ``A`` and ``C`` are the fixtures of the same names. Each result is
    compared against the matching ``scipy.stats`` function, applied to the
    categorical's codes for the rank-based methods.
    """
    stats = pytest.importorskip("scipy.stats")

    def get_codes(s: Series) -> Series:
        # Operate on the argument rather than closing over the fixture
        # ``C``; -1 codes become NaN so they can be omitted or dropped.
        return s.cat.codes.replace(-1, np.nan)

    # Pearson is computed on the category values cast to float.
    result = A.corr(C, method="pearson")
    expected = stats.pearsonr(A[C.notna()], C.dropna().astype("float"))[0]
    tm.assert_almost_equal(result, expected)
    tm.assert_almost_equal(result, 1)

    # Spearman is computed on the codes, which follow the category order.
    result = A.corr(C, method="spearman")
    expected = stats.spearmanr(A, get_codes(C), nan_policy="omit")[0]
    expected_pearson = stats.pearsonr(A[C.notna()], get_codes(C).dropna())[0]

    tm.assert_almost_equal(result, expected)
    tm.assert_almost_equal(result, expected_pearson)
    tm.assert_almost_equal(result, -1)

    result = A.corr(C, method="kendall")
    expected = stats.kendalltau(A, get_codes(C), nan_policy="omit")[0]
    tm.assert_almost_equal(result, expected)
186+
149187
def test_corr_invalid_method(self):
150188
# GH PR #22298
151189
s1 = Series(np.random.default_rng(2).standard_normal(10))

0 commit comments

Comments
 (0)