Skip to content

Commit 70915f8

Browse files
authored
Sets and index sets for sets. (#737)
1 parent 6ecf834 commit 70915f8

File tree

6 files changed

+287
-14
lines changed

6 files changed

+287
-14
lines changed

lib/prelude.dx

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,6 @@ def argFilter {a n} (condition:a->Bool) (xs:n=>a) : List n =
11381138
if condition xs.i
11391139
then append list i
11401140

1141-
11421141
'## Isomorphisms
11431142

11441143
data Iso a b = MkIso { fwd: a -> b & bwd: b -> a }
@@ -1784,6 +1783,38 @@ def argscan {n o} (comp:o->o->Bool) (xs:n=>o) : n =
17841783
def argmin {n o} [Ord o] (xs:n=>o) : n = argscan (<) xs
17851784
def argmax {n o} [Ord o] (xs:n=>o) : n = argscan (>) xs
17861785

1786+
def lexicalOrder {n} [Ord n]
1787+
(compareElements:n->n->Bool)
1788+
(compareLengths:Int->Int->Bool)
1789+
((AsList nx xs):List n) ((AsList ny ys):List n) : Bool =
1790+
-- Orders Lists according to the order of their elements,
1791+
-- in the same way a dictionary does.
1792+
-- For example, this lets us sort Strings.
1793+
--
1794+
-- More precisely, it returns True iff compareElements xs.i ys.i is true
1795+
-- at the first location they differ.
1796+
--
1797+
-- This function operates serially and short-circuits
1798+
-- at the first difference. One could also write this
1799+
-- function as a parallel reduction, but it would be
1800+
-- wasteful in the case where there is an early difference,
1801+
-- because we can't short circuit.
1802+
iter \i.
1803+
case i == min nx ny of
1804+
True -> Done $ compareLengths nx ny
1805+
False ->
1806+
xi = xs.(unsafeFromOrdinal _ i)
1807+
yi = ys.(unsafeFromOrdinal _ i)
1808+
case compareElements xi yi of
1809+
True -> Done True
1810+
False -> case xi == yi of
1811+
True -> Continue
1812+
False -> Done False
1813+
1814+
instance {n} [Ord n] Ord (List n)
1815+
(>) = lexicalOrder (>) (>)
1816+
(<) = lexicalOrder (<) (<)
1817+
17871818
'### clip
17881819

17891820
def clip {a} [Ord a] ((low,high):(a&a)) (x:a) : a =

lib/set.dx

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import sort
2+
3+
4+
'### Monoidal enforcement of uniqueness in sorted lists
5+
6+
def last {n a} (xs:n=>a) : Maybe a =
7+
s = size n
8+
case s == 0 of
9+
True -> Nothing
10+
False -> Just xs.(unsafeFromOrdinal n (s - 1))
11+
12+
def first {n a} (xs:n=>a) : Maybe a =
13+
s = size n
14+
case s == 0 of
15+
True -> Nothing
16+
False -> Just xs.(unsafeFromOrdinal n 0)
17+
18+
def allExceptLast {n a} (xs:n=>a) : List a =
19+
shortSize = Fin (max 0 ((size n) - 1))
20+
allButLast = view i:shortSize. xs.(unsafeFromOrdinal _ (ordinal i))
21+
(AsList _ allButLast)
22+
23+
def mergeUniqueSortedLists {a} [Eq a] (xlist:List a) (ylist:List a) : List a =
24+
-- This function is associative, for use in a monoidal reduction.
25+
-- Assumes all xs are <= all ys.
26+
-- The element at the end of xs might equal the
27+
-- element at the beginning of ys. If so, this
28+
-- function removes the duplicate when concatenating the lists.
29+
(AsList nx xs) = xlist
30+
(AsList _ ys) = ylist
31+
case last xs of
32+
Nothing -> ylist
33+
Just last_x -> case first ys of
34+
Nothing -> xlist
35+
Just first_y -> case last_x == first_y of
36+
False -> concat [xlist, ylist]
37+
True -> concat [allExceptLast xs, ylist]
38+
39+
def removeDuplicatesFromSorted {n a} [Eq a] (xs:n=>a) : List a =
40+
xlists = for i:n. (AsList 1 [xs.i])
41+
reduce (AsList 0 []) mergeUniqueSortedLists xlists
42+
43+
44+
'### Sets
45+
46+
data Set a [Ord a] =
47+
-- Guaranteed to be in sorted order with unique elements,
48+
-- as long as no one else uses this constructor.
49+
-- Instead use the "toSet" function below.
50+
UnsafeAsSet n:Int elements:(Fin n => a)
51+
52+
def toSet {n a} [Ord a] (xs:n=>a) : Set a =
53+
sorted_xs = sort xs
54+
(AsList n' sorted_unique_xs) = removeDuplicatesFromSorted sorted_xs
55+
UnsafeAsSet n' sorted_unique_xs
56+
57+
def setSize {a} ((UnsafeAsSet n _):Set a) : Int = n
58+
59+
instance {a} [Eq a] Eq (Set a)
60+
(==) = \(UnsafeAsSet _ xs) (UnsafeAsSet _ ys).
61+
(AsList _ xs) == (AsList _ ys)
62+
63+
def setUnion {a}
64+
((UnsafeAsSet nx xs):Set a)
65+
((UnsafeAsSet ny ys):Set a) : Set a =
66+
combined = mergeSortedTables xs ys
67+
(AsList n' sorted_unique_xs) = removeDuplicatesFromSorted combined
68+
UnsafeAsSet _ sorted_unique_xs
69+
70+
def setIntersect {a}
71+
((UnsafeAsSet nx xs):Set a)
72+
((UnsafeAsSet ny ys):Set a) : Set a =
73+
-- This could be done in O(nx + ny) instead of O(nx log ny).
74+
isInYs = \x. case searchSorted ys x of
75+
Just x -> True
76+
Nothing -> False
77+
(AsList n' intersection) = filter isInYs xs
78+
UnsafeAsSet _ intersection
79+
80+
81+
'### Index set for sets of strings
82+
83+
-- Todo: Make polymorphic in type. Waiting on a bugfix.
84+
-- data SetIx a l:(Set a) [Ord a] =
85+
86+
data StringSetIx l:(Set String) =
87+
MkSetIx Int -- TODO: Use (Fin (setSize l)) instead.
88+
89+
instance {set} Ix (StringSetIx set)
90+
getSize = \(). setSize set
91+
ordinal = \(MkSetIx i). i
92+
unsafeFromOrdinal = \k. MkSetIx k
93+
94+
instance {set} Eq (StringSetIx set)
95+
(==) = \ix1 ix2. ordinal ix1 == ordinal ix2
96+
97+
-- Todo: Add an interface for converting to and from integer indices.
98+
-- Compiler can't handle the associated type yet.
99+
-- interface AssocIx n -- index sets where indices have data associated with them
100+
-- IxValueType : Type
101+
-- ixValue : n -> IxValueType n
102+
-- lookupIx : IxValueType n -> n
103+
104+
def stringToSetIx {set:Set String} (s:String) : Maybe (StringSetIx set) =
105+
(UnsafeAsSet n elements) = set
106+
maybeIx = searchSorted elements s
107+
case maybeIx of
108+
Nothing -> Nothing
109+
Just i -> Just $ MkSetIx (ordinal i)
110+
111+
def setIxToString {set:Set String} (ix:StringSetIx set) : String =
112+
(UnsafeAsSet n elements) = set
113+
elements.(unsafeFromOrdinal _ (ordinal ix))
114+

lib/sort.dx

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@ def concatTable {a b v} (leftin: a=>v) (rightin: b=>v) : ((a|b)=>v) =
1818
def mergeSortedTables {a m n} [Ord a] (xs:m=>a) (ys:n=>a) : ((m|n)=>a) =
1919
-- Possible improvements:
2020
-- 1) Using a SortedTable type.
21-
-- 2) Avoid initializing the return array.
22-
init = concatTable xs ys
21+
-- 2) Avoid needlessly initializing the return array.
22+
init = concatTable xs ys -- Initialize array of correct size.
2323
yieldState init \buf.
2424
withState (0, 0) \countrefs.
25-
for i.
25+
for i:(m|n).
2626
(cur_x, cur_y) = get countrefs
27-
noYsLeft = cur_y >= size n
28-
stillXsLeft = cur_x < size m
29-
cur_x_at_n = (unsafeFromOrdinal _ cur_x)
30-
cur_y_at_n = (unsafeFromOrdinal _ cur_y)
31-
xIsLess = xs.cur_x_at_n < ys.cur_y_at_n
32-
if noYsLeft || (stillXsLeft && xIsLess)
27+
if cur_y >= size n -- no ys left
3328
then
3429
countrefs := (cur_x + 1, cur_y)
35-
buf!i := xs.cur_x_at_n
30+
buf!i := xs.(unsafeFromOrdinal _ cur_x)
3631
else
37-
countrefs := (cur_x, cur_y + 1)
38-
buf!i := ys.cur_y_at_n
32+
if cur_x < size m -- still xs left
33+
then
34+
if xs.(unsafeFromOrdinal _ cur_x) <= ys.(unsafeFromOrdinal _ cur_y)
35+
then
36+
countrefs := (cur_x + 1, cur_y)
37+
buf!i := xs.(unsafeFromOrdinal _ cur_x)
38+
else
39+
countrefs := (cur_x, cur_y + 1)
40+
buf!i := ys.(unsafeFromOrdinal _ cur_y)
3941

4042
def mergeSortedLists {a} [Ord a] (AsList nx xs: List a) (AsList ny ys: List a) : List a =
4143
-- Need this wrapper because Dex can't automatically weaken

makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \
125125
shadow-tests monad-tests io-tests exception-tests sort-tests \
126126
ad-tests parser-tests serialize-tests parser-combinator-tests \
127127
record-variant-tests typeclass-tests complex-tests trig-tests \
128-
linalg-tests
128+
linalg-tests set-tests
129129

130130
lib-names = diagram plot png
131131

tests/set-tests.dx

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import set
2+
3+
-- check order invariance.
4+
:p (toSet ["Bob", "Alice", "Charlie"]) == (toSet ["Charlie", "Bob", "Alice"])
5+
> True
6+
7+
-- check uniqueness.
8+
:p (toSet ["Bob", "Alice", "Alice", "Charlie"]) == (toSet ["Charlie", "Charlie", "Bob", "Alice"])
9+
> True
10+
11+
set1 = toSet ["Xeno", "Alice", "Bob"]
12+
set2 = toSet ["Bob", "Xeno", "Charlie"]
13+
14+
:p set1 == set2
15+
> False
16+
17+
:p setUnion set1 set2
18+
> (UnsafeAsSet 4 [ (AsList 5 "Alice")
19+
> , (AsList 3 "Bob")
20+
> , (AsList 7 "Charlie")
21+
> , (AsList 4 "Xeno") ])
22+
23+
:p setIntersect set1 set2
24+
> (UnsafeAsSet 2 [(AsList 3 "Bob"), (AsList 4 "Xeno")])
25+
26+
:p removeDuplicatesFromSorted ["Alice", "Alice", "Alice", "Bob", "Bob", "Charlie", "Charlie", "Charlie"]
27+
> (AsList 3 [(AsList 5 "Alice"), (AsList 3 "Bob"), (AsList 7 "Charlie")])
28+
29+
:p set1 == (setUnion set1 set1)
30+
> True
31+
32+
:p set1 == (setIntersect set1 set1)
33+
> True
34+
35+
'#### Empty set tests
36+
37+
emptyset = toSet ([]:(Fin 0)=>String)
38+
39+
:p emptyset == emptyset
40+
> True
41+
42+
:p emptyset == (setUnion emptyset emptyset)
43+
> True
44+
45+
:p emptyset == (setIntersect emptyset emptyset)
46+
> True
47+
48+
:p set1 == (setUnion set1 emptyset)
49+
> True
50+
51+
:p emptyset == (setIntersect set1 emptyset)
52+
> True
53+
54+
'### Set Index Set tests
55+
56+
names2 = toSet ["Bob", "Alice", "Charlie", "Alice"]
57+
58+
:p size (StringSetIx names2)
59+
> 3
60+
61+
-- Check that ordinal and unsafeFromOrdinal are inverses.
62+
roundTrip = for i:(StringSetIx names2).
63+
i == (unsafeFromOrdinal _ (ordinal i))
64+
:p all roundTrip
65+
> True
66+
67+
-- Check that index to string and string to index are inverses.
68+
roundTrip2 = for i:(StringSetIx names2).
69+
s = setIxToString i
70+
ix = stringToSetIx s
71+
i == fromJust ix
72+
:p all roundTrip2
73+
> True
74+
75+
setix : StringSetIx names2 = fromJust $ stringToSetIx "Bob"
76+
:p setix
77+
> (MkSetIx 1)
78+
79+
setix2 : StringSetIx names2 = fromJust $ stringToSetIx "Charlie"
80+
:p setix2
81+
> (MkSetIx 2)

tests/sort-tests.dx

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,48 @@ import sort
44
> True
55
:p isSorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
66
> True
7+
8+
9+
'### Lexical Sorting Tests
10+
11+
:p "aaa" < "bbb"
12+
> True
13+
14+
:p "aa" < "bbb"
15+
> True
16+
17+
:p "a" < "aa"
18+
> True
19+
20+
:p "aaa" > "bbb"
21+
> False
22+
23+
:p "aa" > "bbb"
24+
> False
25+
26+
:p "a" > "aa"
27+
> False
28+
29+
:p "a" < "aa"
30+
> True
31+
32+
:p ("": List Word8) > ("": List Word8)
33+
> False
34+
35+
:p ("": List Word8) < ("": List Word8)
36+
> False
37+
38+
:p "a" > "a"
39+
> False
40+
41+
:p "a" < "a"
42+
> False
43+
44+
:p "Thomas" < "Thompson"
45+
> True
46+
47+
:p "Thomas" > "Thompson"
48+
> False
49+
50+
:p isSorted $ sort ["Charlie", "Alice", "Bob", "Aaron"]
51+
> True

0 commit comments

Comments
 (0)