Skip to content

Commit c2e334f

Browse files
authored
Fix incorrect result for df.sort_values when specifying multiple ascending (#2984)
1 parent a49cf51 commit c2e334f

File tree

3 files changed: 106 additions and 162 deletions.

mars/dataframe/sort/psrs.py

Lines changed: 65 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@
2020
from ... import opcodes as OperandDef
2121
from ...core.operand import OperandStage, MapReduceOperand
2222
from ...utils import lazy_import, calc_nsplits
23-
from ...serialization.serializables import Int32Field, ListField, StringField, BoolField
23+
from ...serialization.serializables import (
24+
AnyField,
25+
Int32Field,
26+
ListField,
27+
StringField,
28+
BoolField,
29+
)
2430
from ...tensor.base.psrs import PSRSOperandMixin
2531
from ..core import IndexValue, OutputType
2632
from ..utils import standardize_range_index, parse_index, is_cudf
@@ -48,6 +54,23 @@ def __gt__(self, other):
4854
_largest = _Largest()
4955

5056

57+
class _ReversedValue:
58+
def __init__(self, value):
59+
self._value = value
60+
61+
def __lt__(self, other):
62+
if type(other) is _ReversedValue:
63+
# may happen when call searchsorted
64+
return self._value >= other._value
65+
return self._value >= other
66+
67+
def __gt__(self, other):
68+
return self._value <= other
69+
70+
def __repr__(self):
71+
return repr(self._value)
72+
73+
5174
class DataFramePSRSOperandMixin(DataFrameOperandMixin, PSRSOperandMixin):
5275
@classmethod
5376
def _collect_op_properties(cls, op):
@@ -380,90 +403,23 @@ def execute_sort_index(data, op, inplace=None):
380403

381404
class DataFramePSRSChunkOperand(DataFrameOperand):
382405
# sort type could be 'sort_values' or 'sort_index'
383-
_sort_type = StringField("sort_type")
406+
sort_type = StringField("sort_type")
384407

385-
_axis = Int32Field("axis")
386-
_by = ListField("by")
387-
_ascending = BoolField("ascending")
388-
_inplace = BoolField("inplace")
389-
_kind = StringField("kind")
390-
_na_position = StringField("na_position")
408+
axis = Int32Field("axis")
409+
by = ListField("by", default=None)
410+
ascending = AnyField("ascending")
411+
inplace = BoolField("inplace")
412+
kind = StringField("kind")
413+
na_position = StringField("na_position")
391414

392415
# for sort_index
393-
_level = ListField("level")
394-
_sort_remaining = BoolField("sort_remaining")
395-
396-
_n_partition = Int32Field("n_partition")
397-
398-
def __init__(
399-
self,
400-
sort_type=None,
401-
by=None,
402-
axis=None,
403-
ascending=None,
404-
inplace=None,
405-
kind=None,
406-
na_position=None,
407-
level=None,
408-
sort_remaining=None,
409-
n_partition=None,
410-
output_types=None,
411-
**kw
412-
):
413-
super().__init__(
414-
_sort_type=sort_type,
415-
_by=by,
416-
_axis=axis,
417-
_ascending=ascending,
418-
_inplace=inplace,
419-
_kind=kind,
420-
_na_position=na_position,
421-
_level=level,
422-
_sort_remaining=sort_remaining,
423-
_n_partition=n_partition,
424-
_output_types=output_types,
425-
**kw
426-
)
416+
level = ListField("level")
417+
sort_remaining = BoolField("sort_remaining")
427418

428-
@property
429-
def sort_type(self):
430-
return self._sort_type
419+
n_partition = Int32Field("n_partition")
431420

432-
@property
433-
def axis(self):
434-
return self._axis
435-
436-
@property
437-
def by(self):
438-
return self._by
439-
440-
@property
441-
def ascending(self):
442-
return self._ascending
443-
444-
@property
445-
def inplace(self):
446-
return self._inplace
447-
448-
@property
449-
def kind(self):
450-
return self._kind
451-
452-
@property
453-
def na_position(self):
454-
return self._na_position
455-
456-
@property
457-
def level(self):
458-
return self._level
459-
460-
@property
461-
def sort_remaining(self):
462-
return self._sort_remaining
463-
464-
@property
465-
def n_partition(self):
466-
return self._n_partition
421+
def __init__(self, output_types=None, **kw):
422+
super().__init__(_output_types=output_types, **kw)
467423

468424

469425
class DataFramePSRSSortRegularSample(DataFramePSRSChunkOperand, DataFrameOperandMixin):
@@ -567,99 +523,49 @@ def execute(cls, ctx, op):
567523
class DataFramePSRSShuffle(MapReduceOperand, DataFrameOperandMixin):
568524
_op_type_ = OperandDef.PSRS_SHUFFLE
569525

570-
_sort_type = StringField("sort_type")
526+
sort_type = StringField("sort_type")
571527

572528
# for shuffle map
573-
_axis = Int32Field("axis")
574-
_by = ListField("by")
575-
_ascending = BoolField("ascending")
576-
_inplace = BoolField("inplace")
577-
_na_position = StringField("na_position")
578-
_n_partition = Int32Field("n_partition")
529+
axis = Int32Field("axis")
530+
by = ListField("by")
531+
ascending = AnyField("ascending")
532+
inplace = BoolField("inplace")
533+
na_position = StringField("na_position")
534+
n_partition = Int32Field("n_partition")
579535

580536
# for sort_index
581-
_level = ListField("level")
582-
_sort_remaining = BoolField("sort_remaining")
537+
level = ListField("level")
538+
sort_remaining = BoolField("sort_remaining")
583539

584540
# for shuffle reduce
585-
_kind = StringField("kind")
586-
587-
def __init__(
588-
self,
589-
sort_type=None,
590-
by=None,
591-
axis=None,
592-
ascending=None,
593-
n_partition=None,
594-
na_position=None,
595-
inplace=None,
596-
kind=None,
597-
level=None,
598-
sort_remaining=None,
599-
output_types=None,
600-
**kw
601-
):
602-
super().__init__(
603-
_sort_type=sort_type,
604-
_by=by,
605-
_axis=axis,
606-
_ascending=ascending,
607-
_n_partition=n_partition,
608-
_na_position=na_position,
609-
_inplace=inplace,
610-
_kind=kind,
611-
_level=level,
612-
_sort_remaining=sort_remaining,
613-
_output_types=output_types,
614-
**kw
615-
)
616-
617-
@property
618-
def sort_type(self):
619-
return self._sort_type
620-
621-
@property
622-
def by(self):
623-
return self._by
624-
625-
@property
626-
def axis(self):
627-
return self._axis
628-
629-
@property
630-
def ascending(self):
631-
return self._ascending
541+
kind = StringField("kind")
632542

633-
@property
634-
def inplace(self):
635-
return self._inplace
636-
637-
@property
638-
def na_position(self):
639-
return self._na_position
640-
641-
@property
642-
def level(self):
643-
return self._level
644-
645-
@property
646-
def sort_remaining(self):
647-
return self._sort_remaining
648-
649-
@property
650-
def n_partition(self):
651-
return self._n_partition
652-
653-
@property
654-
def kind(self):
655-
return self._kind
543+
def __init__(self, output_types=None, **kw):
544+
super().__init__(_output_types=output_types, **kw)
656545

657546
@property
658547
def output_limit(self):
659548
return 1
660549

661550
@staticmethod
662551
def _calc_poses(src_cols, pivots, ascending=True):
552+
if isinstance(ascending, list):
553+
for asc, col in zip(ascending, pivots.columns):
554+
# When a mixed sort order is specified, transform each descending column so that all pivots can be compared in ascending order
555+
if not asc:
556+
if pd.api.types.is_numeric_dtype(pivots.dtypes[col]):
557+
# for numeric dtypes, negating the values is more efficient
558+
pivots[col] = -pivots[col]
559+
src_cols[col] = -src_cols[col]
560+
else:
561+
# for other dtypes, wrap each value in _ReversedValue
562+
pivots[col] = pivots[col].map(
563+
lambda x: x
564+
if type(x) is _ReversedValue
565+
else _ReversedValue(x)
566+
)
567+
ascending = True
568+
663569
records = src_cols.to_records(index=False)
664570
p_records = pivots.to_records(index=False)
665571
if ascending:

mars/dataframe/sort/sort_values.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ def dataframe_sort_values(
252252
raise NotImplementedError("Only support sort on axis 0")
253253
psrs_kinds = _validate_sort_psrs_kinds(psrs_kinds)
254254
by = by if isinstance(by, (list, tuple)) else [by]
255+
if isinstance(ascending, list): # pragma: no cover
256+
if all(ascending):
257+
# all are True, convert to True
258+
ascending = True
259+
elif not any(ascending):
260+
# all are False, convert to False
261+
ascending = False
255262
op = DataFrameSortValues(
256263
by=by,
257264
axis=axis,

mars/dataframe/sort/tests/test_sort_execution.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
"distinct_opt", ["0"] if sys.platform.lower().startswith("win") else ["0", "1"]
2828
)
2929
def test_sort_values_execution(setup, distinct_opt):
30+
ns = np.random.RandomState(0)
3031
os.environ["PSRS_DISTINCT_COL"] = distinct_opt
31-
df = pd.DataFrame(
32-
np.random.rand(100, 10), columns=["a" + str(i) for i in range(10)]
33-
)
32+
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])
3433

3534
# test one chunk
3635
mdf = DataFrame(df)
@@ -67,6 +66,38 @@ def test_sort_values_execution(setup, distinct_opt):
6766

6867
pd.testing.assert_frame_equal(result, expected)
6968

69+
# test ascending is a list
70+
result = (
71+
mdf.sort_values(["a3", "a4", "a5", "a6"], ascending=[False, True, True, False])
72+
.execute()
73+
.fetch()
74+
)
75+
expected = df.sort_values(
76+
["a3", "a4", "a5", "a6"], ascending=[False, True, True, False]
77+
)
78+
pd.testing.assert_frame_equal(result, expected)
79+
80+
in_df = pd.DataFrame(
81+
{
82+
"col1": ns.choice([f"a{i}" for i in range(5)], size=(100,)),
83+
"col2": ns.choice([f"b{i}" for i in range(5)], size=(100,)),
84+
"col3": ns.choice([f"c{i}" for i in range(5)], size=(100,)),
85+
"col4": ns.randint(10, 20, size=(100,)),
86+
}
87+
)
88+
mdf = DataFrame(in_df, chunk_size=10)
89+
result = (
90+
mdf.sort_values(
91+
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
92+
)
93+
.execute()
94+
.fetch()
95+
)
96+
expected = in_df.sort_values(
97+
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
98+
)
99+
pd.testing.assert_frame_equal(result, expected)
100+
70101
# test multiindex
71102
df2 = df.copy(deep=True)
72103
df2.columns = pd.MultiIndex.from_product([list("AB"), list("CDEFG")])

0 commit comments

Comments
 (0)