|
20 | 20 | from ... import opcodes as OperandDef |
21 | 21 | from ...core.operand import OperandStage, MapReduceOperand |
22 | 22 | from ...utils import lazy_import, calc_nsplits |
23 | | -from ...serialization.serializables import Int32Field, ListField, StringField, BoolField |
| 23 | +from ...serialization.serializables import ( |
| 24 | + AnyField, |
| 25 | + Int32Field, |
| 26 | + ListField, |
| 27 | + StringField, |
| 28 | + BoolField, |
| 29 | +) |
24 | 30 | from ...tensor.base.psrs import PSRSOperandMixin |
25 | 31 | from ..core import IndexValue, OutputType |
26 | 32 | from ..utils import standardize_range_index, parse_index, is_cudf |
@@ -48,6 +54,23 @@ def __gt__(self, other): |
48 | 54 | _largest = _Largest() |
49 | 55 |
|
50 | 56 |
|
| 57 | +class _ReversedValue: |
| 58 | + def __init__(self, value): |
| 59 | + self._value = value |
| 60 | + |
| 61 | + def __lt__(self, other): |
| 62 | + if type(other) is _ReversedValue: |
| 63 | + # may happen when call searchsorted |
| 64 | + return self._value >= other._value |
| 65 | + return self._value >= other |
| 66 | + |
| 67 | + def __gt__(self, other): |
| 68 | + return self._value <= other |
| 69 | + |
| 70 | + def __repr__(self): |
| 71 | + return repr(self._value) |
| 72 | + |
| 73 | + |
51 | 74 | class DataFramePSRSOperandMixin(DataFrameOperandMixin, PSRSOperandMixin): |
52 | 75 | @classmethod |
53 | 76 | def _collect_op_properties(cls, op): |
@@ -380,90 +403,23 @@ def execute_sort_index(data, op, inplace=None): |
380 | 403 |
|
381 | 404 | class DataFramePSRSChunkOperand(DataFrameOperand): |
382 | 405 | # sort type could be 'sort_values' or 'sort_index' |
383 | | - _sort_type = StringField("sort_type") |
| 406 | + sort_type = StringField("sort_type") |
384 | 407 |
|
385 | | - _axis = Int32Field("axis") |
386 | | - _by = ListField("by") |
387 | | - _ascending = BoolField("ascending") |
388 | | - _inplace = BoolField("inplace") |
389 | | - _kind = StringField("kind") |
390 | | - _na_position = StringField("na_position") |
| 408 | + axis = Int32Field("axis") |
| 409 | + by = ListField("by", default=None) |
| 410 | + ascending = AnyField("ascending") |
| 411 | + inplace = BoolField("inplace") |
| 412 | + kind = StringField("kind") |
| 413 | + na_position = StringField("na_position") |
391 | 414 |
|
392 | 415 | # for sort_index |
393 | | - _level = ListField("level") |
394 | | - _sort_remaining = BoolField("sort_remaining") |
395 | | - |
396 | | - _n_partition = Int32Field("n_partition") |
397 | | - |
398 | | - def __init__( |
399 | | - self, |
400 | | - sort_type=None, |
401 | | - by=None, |
402 | | - axis=None, |
403 | | - ascending=None, |
404 | | - inplace=None, |
405 | | - kind=None, |
406 | | - na_position=None, |
407 | | - level=None, |
408 | | - sort_remaining=None, |
409 | | - n_partition=None, |
410 | | - output_types=None, |
411 | | - **kw |
412 | | - ): |
413 | | - super().__init__( |
414 | | - _sort_type=sort_type, |
415 | | - _by=by, |
416 | | - _axis=axis, |
417 | | - _ascending=ascending, |
418 | | - _inplace=inplace, |
419 | | - _kind=kind, |
420 | | - _na_position=na_position, |
421 | | - _level=level, |
422 | | - _sort_remaining=sort_remaining, |
423 | | - _n_partition=n_partition, |
424 | | - _output_types=output_types, |
425 | | - **kw |
426 | | - ) |
| 416 | + level = ListField("level") |
| 417 | + sort_remaining = BoolField("sort_remaining") |
427 | 418 |
|
428 | | - @property |
429 | | - def sort_type(self): |
430 | | - return self._sort_type |
| 419 | + n_partition = Int32Field("n_partition") |
431 | 420 |
|
432 | | - @property |
433 | | - def axis(self): |
434 | | - return self._axis |
435 | | - |
436 | | - @property |
437 | | - def by(self): |
438 | | - return self._by |
439 | | - |
440 | | - @property |
441 | | - def ascending(self): |
442 | | - return self._ascending |
443 | | - |
444 | | - @property |
445 | | - def inplace(self): |
446 | | - return self._inplace |
447 | | - |
448 | | - @property |
449 | | - def kind(self): |
450 | | - return self._kind |
451 | | - |
452 | | - @property |
453 | | - def na_position(self): |
454 | | - return self._na_position |
455 | | - |
456 | | - @property |
457 | | - def level(self): |
458 | | - return self._level |
459 | | - |
460 | | - @property |
461 | | - def sort_remaining(self): |
462 | | - return self._sort_remaining |
463 | | - |
464 | | - @property |
465 | | - def n_partition(self): |
466 | | - return self._n_partition |
| 421 | + def __init__(self, output_types=None, **kw): |
| 422 | + super().__init__(_output_types=output_types, **kw) |
467 | 423 |
|
468 | 424 |
|
469 | 425 | class DataFramePSRSSortRegularSample(DataFramePSRSChunkOperand, DataFrameOperandMixin): |
@@ -567,99 +523,49 @@ def execute(cls, ctx, op): |
567 | 523 | class DataFramePSRSShuffle(MapReduceOperand, DataFrameOperandMixin): |
568 | 524 | _op_type_ = OperandDef.PSRS_SHUFFLE |
569 | 525 |
|
570 | | - _sort_type = StringField("sort_type") |
| 526 | + sort_type = StringField("sort_type") |
571 | 527 |
|
572 | 528 | # for shuffle map |
573 | | - _axis = Int32Field("axis") |
574 | | - _by = ListField("by") |
575 | | - _ascending = BoolField("ascending") |
576 | | - _inplace = BoolField("inplace") |
577 | | - _na_position = StringField("na_position") |
578 | | - _n_partition = Int32Field("n_partition") |
| 529 | + axis = Int32Field("axis") |
| 530 | + by = ListField("by") |
| 531 | + ascending = AnyField("ascending") |
| 532 | + inplace = BoolField("inplace") |
| 533 | + na_position = StringField("na_position") |
| 534 | + n_partition = Int32Field("n_partition") |
579 | 535 |
|
580 | 536 | # for sort_index |
581 | | - _level = ListField("level") |
582 | | - _sort_remaining = BoolField("sort_remaining") |
| 537 | + level = ListField("level") |
| 538 | + sort_remaining = BoolField("sort_remaining") |
583 | 539 |
|
584 | 540 | # for shuffle reduce |
585 | | - _kind = StringField("kind") |
586 | | - |
587 | | - def __init__( |
588 | | - self, |
589 | | - sort_type=None, |
590 | | - by=None, |
591 | | - axis=None, |
592 | | - ascending=None, |
593 | | - n_partition=None, |
594 | | - na_position=None, |
595 | | - inplace=None, |
596 | | - kind=None, |
597 | | - level=None, |
598 | | - sort_remaining=None, |
599 | | - output_types=None, |
600 | | - **kw |
601 | | - ): |
602 | | - super().__init__( |
603 | | - _sort_type=sort_type, |
604 | | - _by=by, |
605 | | - _axis=axis, |
606 | | - _ascending=ascending, |
607 | | - _n_partition=n_partition, |
608 | | - _na_position=na_position, |
609 | | - _inplace=inplace, |
610 | | - _kind=kind, |
611 | | - _level=level, |
612 | | - _sort_remaining=sort_remaining, |
613 | | - _output_types=output_types, |
614 | | - **kw |
615 | | - ) |
616 | | - |
617 | | - @property |
618 | | - def sort_type(self): |
619 | | - return self._sort_type |
620 | | - |
621 | | - @property |
622 | | - def by(self): |
623 | | - return self._by |
624 | | - |
625 | | - @property |
626 | | - def axis(self): |
627 | | - return self._axis |
628 | | - |
629 | | - @property |
630 | | - def ascending(self): |
631 | | - return self._ascending |
| 541 | + kind = StringField("kind") |
632 | 542 |
|
633 | | - @property |
634 | | - def inplace(self): |
635 | | - return self._inplace |
636 | | - |
637 | | - @property |
638 | | - def na_position(self): |
639 | | - return self._na_position |
640 | | - |
641 | | - @property |
642 | | - def level(self): |
643 | | - return self._level |
644 | | - |
645 | | - @property |
646 | | - def sort_remaining(self): |
647 | | - return self._sort_remaining |
648 | | - |
649 | | - @property |
650 | | - def n_partition(self): |
651 | | - return self._n_partition |
652 | | - |
653 | | - @property |
654 | | - def kind(self): |
655 | | - return self._kind |
| 543 | + def __init__(self, output_types=None, **kw): |
| 544 | + super().__init__(_output_types=output_types, **kw) |
656 | 545 |
|
657 | 546 | @property |
658 | 547 | def output_limit(self): |
659 | 548 | return 1 |
660 | 549 |
|
661 | 550 | @staticmethod |
662 | 551 | def _calc_poses(src_cols, pivots, ascending=True): |
| 552 | + if isinstance(ascending, list): |
| 553 | + for asc, col in zip(ascending, pivots.columns): |
| 554 | + # Make pivots available to use ascending order when mixed order specified |
| 555 | + if not asc: |
| 556 | + if pd.api.types.is_numeric_dtype(pivots.dtypes[col]): |
| 557 | + # for numeric dtypes, convert to negative is more efficient |
| 558 | + pivots[col] = -pivots[col] |
| 559 | + src_cols[col] = -src_cols[col] |
| 560 | + else: |
| 561 | + # for other types, convert to ReversedValue |
| 562 | + pivots[col] = pivots[col].map( |
| 563 | + lambda x: x |
| 564 | + if type(x) is _ReversedValue |
| 565 | + else _ReversedValue(x) |
| 566 | + ) |
| 567 | + ascending = True |
| 568 | + |
663 | 569 | records = src_cols.to_records(index=False) |
664 | 570 | p_records = pivots.to_records(index=False) |
665 | 571 | if ascending: |
|
0 commit comments