Skip to content

Commit 6d23e46

Browse files
committed
Following up on PR 1296, implement both APIs for search_sorted.
Context: #1296 (comment) The more basic API, still named `search_sorted`, returns a `Post n`. The idea is that the elements of `xs` are fence sections, and we find the position between them (inclusive on either end) where `x` falls in the ordering. In terms of this, we now define `search_sorted_exact` (better name?), which returns a `Maybe n`, which is the index of an element of `xs` that equals `x` exactly, or `Nothing` if such does not exist. Also reorder the prelude slightly to try to both maintain semantic groupings and respect name resolution dependencies.
1 parent e446273 commit 6d23e46

File tree

4 files changed

+73
-38
lines changed

4 files changed

+73
-38
lines changed

lib/prelude.dx

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -591,34 +591,6 @@ def i_to_n(x:Int) -> Maybe Nat =
591591
then Nothing
592592
else Just $ unsafe_i_to_n x
593593

594-
'## Fencepost index sets
595-
596-
struct Post(segment:Type) =
597-
val : Nat
598-
599-
instance Ix(Post segment) given (segment|Ix)
600-
def size'() = size segment + 1
601-
def ordinal(i) = i.val
602-
def unsafe_from_ordinal(i) = Post(i)
603-
604-
def left_post(i:n) -> Post n given (n|Ix) =
605-
unsafe_from_ordinal(n=Post n, ordinal i)
606-
607-
def right_post(i:n) -> Post n given (n|Ix) =
608-
unsafe_from_ordinal(n=Post n, ordinal i + 1)
609-
610-
interface NonEmpty(n|Ix)
611-
first_ix : n
612-
613-
def last_ix() ->> n given (n|NonEmpty) =
614-
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))
615-
616-
instance NonEmpty(Post n) given (n|Ix)
617-
first_ix = unsafe_from_ordinal(n=Post n, 0)
618-
619-
instance NonEmpty(())
620-
first_ix = unsafe_from_ordinal(0)
621-
622594
'### Monoid
623595
A [monoid](https://en.wikipedia.org/wiki/Monoid) is a thing that has an associative binary operator and an identity element.
624596
This is a very useful and general class of things.
@@ -901,6 +873,12 @@ instance Ix(Maybe a) given (a|Ix)
901873
False -> Just $ unsafe_from_ordinal o
902874
True -> Nothing
903875

876+
interface NonEmpty(n|Ix)
877+
first_ix : n
878+
879+
instance NonEmpty(())
880+
first_ix = unsafe_from_ordinal(0)
881+
904882
instance NonEmpty(Bool)
905883
first_ix = unsafe_from_ordinal 0
906884

@@ -918,6 +896,40 @@ instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
918896
instance NonEmpty(Maybe a) given (a|Ix)
919897
first_ix = unsafe_from_ordinal 0
920898

899+
'## Fencepost index sets
900+
901+
struct Post(segment:Type) =
902+
val : Nat
903+
904+
instance Ix(Post segment) given (segment|Ix)
905+
def size'() = size segment + 1
906+
def ordinal(i) = i.val
907+
def unsafe_from_ordinal(i) = Post(i)
908+
909+
def left_post(i:n) -> Post n given (n|Ix) =
910+
unsafe_from_ordinal(n=Post n, ordinal i)
911+
912+
def right_post(i:n) -> Post n given (n|Ix) =
913+
unsafe_from_ordinal(n=Post n, ordinal i + 1)
914+
915+
def left_fence(p:Post n) -> Maybe n given (n|Ix) =
916+
ix = ordinal p
917+
if ix == 0
918+
then Nothing
919+
else Just $ unsafe_from_ordinal $ ix -| 1
920+
921+
def right_fence(p:Post n) -> Maybe n given (n|Ix) =
922+
ix = ordinal p
923+
if ix == size n
924+
then Nothing
925+
else Just $ unsafe_from_ordinal ix
926+
927+
def last_ix() ->> n given (n|NonEmpty) =
928+
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))
929+
930+
instance NonEmpty(Post n) given (n|Ix)
931+
first_ix = unsafe_from_ordinal(n=Post n, 0)
932+
921933
def scan(
922934
init:a,
923935
body:(n, a)->(a,b)
@@ -2016,27 +2028,45 @@ instance Arbitrary(Fin n) given (n)
20162028

20172029
'### Searching
20182030

2019-
'returns the highest index `i` such that `xs.i <= x`
2031+
'Returns the bucket of `x` assuming boundaries `xs` as a `Post n`.
2032+
The boundaries must already be sorted, and are inclusive on the left.
2033+
2034+
'In other words, if there is an index `i` such that `xs.i <= x`,
2035+
returns the `right_post` of the highest such index; otherwise returns
2036+
`first_ix : Post n`, which is also the `left_post` of the minimum `i`.
20202037

2021-
def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
2038+
'This is equivalent to the right-biased formulation: if an index `i`
2039+
exists such that `x < xs.i`, returns the `left_post` of the least such
2040+
`i`, otherwise returns `last_ix : Post n`, i.e., the `right_post` of
2041+
the maximum `i`.
2042+
2043+
def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) =
20222044
if size n == 0
2023-
then Nothing
2045+
then first_ix
20242046
else if x < xs[from_ordinal 0]
2025-
then Nothing
2047+
then first_ix
20262048
else
20272049
low <- with_state(0)
20282050
high <- with_state(size n)
20292051
_ <- iter
20302052
numLeft = n_to_i (get high) - n_to_i (get low)
20312053
if numLeft == 1
2032-
then Done $ Just $ from_ordinal $ get low
2054+
then Done $ right_post $ from_ordinal $ get low
20332055
else
20342056
centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
20352057
if x < xs[from_ordinal centerIx]
20362058
then high := centerIx
20372059
else low := centerIx
20382060
Continue
20392061

2062+
'If `i` exists such that `xs.i == x`, returns `Just` of the largest
2063+
such `i`, otherwise returns `Nothing`.
2064+
2065+
def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
2066+
case left_fence(search_sorted(xs, x)) of
2067+
Just i -> if xs[i] == x then Just i else Nothing
2068+
Nothing -> Nothing
2069+
20402070
'### min / max etc
20412071

20422072
def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
@@ -2318,8 +2348,7 @@ def lines(source:String) -> List String =
23182348
-- cdf should include 0.0 but not 1.0
23192349
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
23202350
r = rand key
2321-
case search_sorted(cdf, r) of
2322-
Just(i) -> i
2351+
from_just $ left_fence $ search_sorted(cdf, r)
23232352

23242353
def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs
23252354

lib/set.dx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def set_intersect(
8282
UnsafeAsSet(nx, xs) = sx
8383
UnsafeAsSet(ny, ys) = sy
8484
-- This could be done in O(nx + ny) instead of O(nx log ny).
85-
isInYs = \x. case search_sorted ys x of
85+
isInYs = \x. case search_sorted_exact ys x of
8686
Just x -> True
8787
Nothing -> False
8888
AsList(n', intersection) = filter xs isInYs
@@ -100,7 +100,7 @@ struct Element(set:(Set a)) given (a|Ord) =
100100
-- type), but maybe it's easier to read if it's explicit.
101101
def member(x:a, set:(Set a)) -> Maybe (Element set) given (a|Ord) =
102102
UnsafeAsSet(_, elts) = set
103-
case search_sorted elts x of
103+
case search_sorted_exact elts x of
104104
Just n -> Just $ Element(ordinal n)
105105
Nothing -> Nothing
106106

lib/stats.dx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ instance OrderedDist(Binomial, Nat, Float)
166166
lpdf = for i:(Fin tp1). ln $ density d (ordinal i)
167167
cdf = cdf_for_categorical lpdf
168168
mi = search_sorted cdf q
169-
ordinal $ from_just mi
169+
ordinal $ from_just $ left_fence mi
170170

171171

172172
'### Exponential distribution

tests/sort-tests.dx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ import sort
55
:p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
66
> True
77

8+
:p
9+
xs = [1,2,4]
10+
for i:(Fin 6).
11+
search_sorted_exact(xs, ordinal i)
12+
> [Nothing, (Just 0), (Just 1), Nothing, (Just 2), Nothing]
13+
814
'### Lexical Sorting Tests
915

1016
:p "aaa" < "bbb"

0 commit comments

Comments
 (0)