Skip to content

Commit 957517a

Browse files
authored
Make 'list::sort' quick by using natural mergesort (#711)
Please see issue #710 for context and analysis. (Resolves #710)
1 parent f4bf8a7 commit 957517a

File tree

3 files changed

+117
-18
lines changed

3 files changed

+117
-18
lines changed

examples/stdlib/list/sortBy.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Cons(-1, Cons(1, Cons(3, Cons(5, Nil()))))
2+
Cons(5, Cons(3, Cons(1, Cons(-1, Nil()))))
3+
Cons((-1, 1), Cons((0, 0), Cons((1, 0), Cons((0, 1), Nil()))))
4+
Nil()

examples/stdlib/list/sortBy.effekt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module examples/pos/list/sortBy
2+
3+
import list
4+
5+
def main() = {
6+
// synchronized with doctest in `sortBy`
7+
println([1, 3, -1, 5].sortBy { (a, b) => a <= b })
8+
println("Cons(5, Cons(3, Cons(1, Cons(-1, Nil()))))")
9+
println("Cons((-1, 1), Cons((0, 0), Cons((1, 0), Cons((0, 1), Nil()))))")
10+
println("Nil()")
11+
//println([1, 3, -1, 5].sortBy { (a, b) => a >= b })
12+
13+
//val sorted: List[(Int, Int)] = [(1, 0), (0, 1), (-1, 1), (0, 0)]
14+
// .sortBy { (a, b) => a.first + a.second <= b.first + b.second }
15+
//println(show(sorted.map { case (a, b) => "(" ++ show(a) ++ ", " ++ show(b) ++ ")" }))
16+
//println(Nil[Int]().sortBy { (a, b) => a <= b })
17+
}

libraries/common/list.effekt

Lines changed: 96 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -670,35 +670,113 @@ def partition[A](l: List[A]) { pred: A => Bool }: (List[A], List[A]) = {
670670
(lefts.reverse, rights.reverse)
671671
}
672672

673-
/// Sort a list using a given comparison function.
673+
/// Utilities for sorting, see 'sortBy' for more details.
674+
namespace sort {
675+
/// Splits the given list into monotonic segments (so a list of lists).
676+
///
677+
/// Internally used in the mergesort 'sortBy' to prepare the to-be-merged partitions.
678+
def sequences[A](list: List[A]) { compare: (A, A) => Bool }: List[List[A]] = list match {
679+
case Cons(a, Cons(b, rest)) =>
680+
if (compare(a, b)) {
681+
ascending(b, rest) { diffRest => Cons(a, diffRest) } {compare}
682+
} else {
683+
descending(b, [a], rest) {compare}
684+
}
685+
case _ => [list]
686+
}
687+
688+
/// When in an ascending sequence, try to add `current` to `run` (if possible)
689+
def ascending[A](current: A, rest: List[A]) { runDiff: List[A] => List[A] } { compare: (A, A) => Bool }: List[List[A]] = rest match {
690+
case Cons(next, tail) and compare(current, next) =>
691+
ascending(next, tail) { diffRest => runDiff(Cons(current, diffRest)) } {compare}
692+
case _ => Cons(runDiff([current]), sequences(rest) {compare})
693+
}
694+
695+
/// When in an descending sequence, try to add `current` to `run` (if possible)
696+
def descending[A](current: A, run: List[A], rest: List[A]) { compare: (A, A) => Bool }: List[List[A]] = rest match {
697+
case Cons(next, tail) and not(compare(current, next)) =>
698+
descending(next, Cons(current, run), tail) {compare}
699+
case _ => Cons(Cons(current, run), sequences(rest) {compare})
700+
}
701+
702+
def mergeAll[A](runs: List[List[A]]) { compare: (A, A) => Bool }: List[A] = runs match {
703+
case Cons(single, Nil()) => single
704+
case _ => {
705+
// recursively merge in pairs until there's only a single list
706+
val newRuns = mergePairs(runs) {compare}
707+
mergeAll(newRuns) {compare}
708+
}
709+
}
710+
711+
def mergePairs[A](runs: List[List[A]]) { compare: (A, A) => Bool }: List[List[A]] = runs match {
712+
case Cons(a, Cons(b, rest)) =>
713+
Cons(merge(a, b) {compare}, mergePairs(rest) {compare})
714+
case _ => runs
715+
}
716+
717+
def merge[A](l1: List[A], l2: List[A]) { compare: (A, A) => Bool }: List[A] =
718+
(l1, l2) match {
719+
case (Nil(), _) => l2
720+
case (_, Nil()) => l1
721+
case (Cons(h1, t1), Cons(h2, t2)) =>
722+
if (compare(h1, h2)) {
723+
Cons(h1, merge(t1, l2) {compare})
724+
} else {
725+
Cons(h2, merge(l1, t2) {compare})
726+
}
727+
}
728+
}
729+
730+
/// Sort a list given a comparison operator (like less-or-equal!)
731+
/// The sorting algorithm is stable and should act reasonably well on partially sorted data.
732+
///
733+
/// Examples:
734+
/// ```
735+
/// > [1, 3, -1, 5].sortBy { (a, b) => a <= b }
736+
/// [-1, 1, 3, 5]
737+
///
738+
/// > [1, 3, -1, 5].sortBy { (a, b) => a >= b }
739+
/// [5, 3, 1, -1]
740+
///
741+
/// > [(1, 0), (0, 1), (-1, 1), (0, 0)].sortBy { (a, b) => a.first + a.second <= b.first + b.second }
742+
/// [(1, -1), (0, 0), (1, 0), (0, 1)]
743+
///
744+
/// > Nil[Int]().sortBy { (a, b) => a <= b }
745+
/// []
746+
/// ```
674747
///
675748
/// Note: this implementation is not stacksafe!
749+
/// (works for ~5M random elements just fine, but OOMs on ~10M random elements)
676750
///
677-
/// O(N log N)
678-
def sortBy[A](l: List[A]) { compare: (A, A) => Bool }: List[A] =
679-
l match {
680-
case Nil() => Nil()
681-
case Cons(pivot, rest) =>
682-
val (lt, gt) = rest.partition { el => compare(el, pivot) };
683-
val leftSorted = sortBy(lt) { (a, b) => compare(a, b) }
684-
val rightSorted = sortBy(gt) { (a, b) => compare(a, b) }
685-
leftSorted.append(Cons(pivot, rightSorted))
686-
}
751+
/// O(N log N) worstcase
752+
def sortBy[A](list: List[A]) { lessOrEqual: (A, A) => Bool }: List[A] = {
753+
val monotonicRuns = sort::sequences(list) {lessOrEqual}
754+
sort::mergeAll(monotonicRuns) {lessOrEqual}
755+
}
756+
757+
/// Sort a list of integers in an ascending order.
758+
/// See 'sortBy' for more details.
759+
///
760+
/// O(N log N) worstcase
761+
def sort(l: List[Int]): List[Int] = l.sortBy { (a, b) => a <= b }
762+
763+
/// Sort a list of doubles in an ascending order.
764+
/// See 'sortBy' for more details.
765+
///
766+
/// O(N log N) worstcase
767+
def sort(l: List[Double]): List[Double] = l.sortBy { (a, b) => a <= b }
687768

688-
def sort(l: List[Int]): List[Int] = l.sortBy { (a, b) => a < b }
689-
def sort(l: List[Double]): List[Double] = l.sortBy { (a, b) => a < b }
690769

691-
/// Check if a list is sorted according to the given comparison function.
770+
/// Check if a list is sorted according to the given comparison function (less-or-equal).
692771
///
693772
/// O(N)
694-
def isSortedBy[A](list: List[A]) { compare: (A, A) => Bool }: Bool = {
773+
def isSortedBy[A](list: List[A]) { lessOrEqual: (A, A) => Bool }: Bool = {
695774
def go(list: List[A]): Bool = {
696775
list match {
697-
case Nil() => true
698-
case Cons(x, Nil()) => true
699776
case Cons(x, Cons(y, rest)) =>
700777
val next = Cons(y, rest) // Future work: Replace this by an @-pattern!
701-
compare(x, y) && go(next)
778+
lessOrEqual(x, y) && go(next)
779+
case _ => true
702780
}
703781
}
704782
go(list)

0 commit comments

Comments
 (0)