Skip to content

Commit 5e0b7e3

Browse files
sfc-gh-vrpatel authored and sfc-gh-mvashishtha
committed
FIX-#7551: Fix name ambiguity for value_counts() on Pandas backend (#7585)
Implements index renaming for `sort_rows_by_column_values()` in the Pandas backend, similarly to the Ray backend, to fix the name-ambiguity error in `value_counts()`.

Signed-off-by: Vraj Patel <vraj.patel@snowflake.com>
1 parent 7e189a8 commit 5e0b7e3

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

modin/core/storage_formats/base/query_compiler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3079,9 +3079,21 @@ def sort_rows_by_column_values(
30793079
BaseQueryCompiler
30803080
New QueryCompiler that contains result of the sort.
30813081
"""
3082-
return DataFrameDefault.register(pandas.DataFrame.sort_values)(
3082+
# Avoid index/column name collisions by renaming and restoring after sorting
3083+
index_renaming = None
3084+
if is_scalar(columns):
3085+
columns = [columns]
3086+
if any(name in columns for name in self.index.names):
3087+
index_renaming = self.index.names
3088+
self.index = self.index.set_names([None] * len(self.index.names))
3089+
new_query_compiler = DataFrameDefault.register(pandas.DataFrame.sort_values)(
30833090
self, by=columns, axis=0, ascending=ascending, **kwargs
30843091
)
3092+
if index_renaming is not None:
3093+
new_query_compiler.index = new_query_compiler.index.set_names(
3094+
index_renaming
3095+
)
3096+
return new_query_compiler
30853097

30863098
# END Abstract map across rows/columns
30873099

modin/pandas/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3833,7 +3833,16 @@ def value_counts(
38333833
by=subset, dropna=dropna, observed=True, sort=False
38343834
).size()
38353835
if sort:
3836-
counted_values.sort_values(ascending=ascending, inplace=True)
3836+
if counted_values.name is None:
3837+
counted_values.name = 0
3838+
by = counted_values.name
3839+
result = counted_values._query_compiler.sort_rows_by_column_values(
3840+
columns=by,
3841+
ascending=ascending,
3842+
)
3843+
counted_values = self._create_or_update_from_compiler(result)
3844+
if isinstance(counted_values, pd.DataFrame):
3845+
counted_values = counted_values.squeeze(axis=1)
38373846
if normalize:
38383847
counted_values = counted_values / counted_values.sum()
38393848
# TODO: uncomment when strict compability mode will be implemented:

modin/tests/pandas/dataframe/test_join_sort.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,3 +1012,50 @@ def test_compare(align_axis, keep_shape, keep_equal):
10121012
modin_result = modin_series2.compare(modin_series1, **kwargs)
10131013
pandas_result = pandas_series2.compare(pandas_series1, **kwargs)
10141014
assert to_pandas(modin_result).equals(pandas_result)
1015+
1016+
1017+
@pytest.mark.parametrize(
1018+
"params",
1019+
[
1020+
{"ascending": True},
1021+
{"normalize": True},
1022+
pytest.param(
1023+
{"sort": False},
1024+
marks=(
1025+
pytest.mark.xfail(
1026+
reason="Known issue with sort=False in `groupby()` "
1027+
+ "(https://github.com/modin-project/modin/issues/3571)",
1028+
strict=True,
1029+
)
1030+
if Engine.get() in ("Python", "Ray", "Dask", "Unidist")
1031+
and StorageFormat.get() != "Base"
1032+
else []
1033+
),
1034+
),
1035+
],
1036+
)
1037+
def test_value_counts(params):
1038+
data = [[4, 1, 3, 2], [2, 5, 6, 5], [4, 3, 3, 5]]
1039+
columns = ["col1", "col2", "col3", "col4"]
1040+
1041+
eval_general(
1042+
*create_test_dfs(data, columns=columns),
1043+
lambda df: df["col1"].value_counts(**params),
1044+
)
1045+
1046+
1047+
def test_value_counts_with_nulls():
1048+
data = [[5, 6, None, 7, 7], [None, None, 5, 8]]
1049+
eval_general(*create_test_dfs(data), lambda df: df[0].value_counts(dropna=False))
1050+
1051+
1052+
def test_value_counts_with_multiindex():
1053+
data = [[1, 2, 2, 4]]
1054+
index = pd.MultiIndex.from_arrays(
1055+
arrays=[["a", "a", "b", "b"], [1, 2, 1, 2]], names=("l1", "l2")
1056+
)
1057+
1058+
eval_general(
1059+
*create_test_dfs(data, index=index),
1060+
lambda df: df[0].value_counts(),
1061+
)

0 commit comments

Comments (0)