Skip to content

Commit c66e51d

Browse files
committed
Add a bunch of missing INLINE/INLINABLE pragmas in the parallel-merge code
1 parent 1ef22da commit c66e51d

File tree

6 files changed

+191
-130
lines changed

6 files changed

+191
-130
lines changed

.github/workflows/build-test-linear.yaml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,14 @@ jobs:
145145
- name: Make sure benchrunner builds and runs
146146
run: |
147147
cabal build benchrunner
148-
cabal run benchrunner -- 5 Insertionsort Seq 100
149-
cabal run benchrunner -- 5 Mergesort Seq 100
150-
cabal run benchrunner -- 5 Mergesort Par 100 +RTS -N2
151-
cabal run benchrunner -- 5 "VectorSort Insertionsort" Seq 100
152-
cabal run benchrunner -- 5 "VectorSort Mergesort" Seq 100
153-
cabal run benchrunner -- 5 "VectorSort Quicksort" Seq 100
154-
cabal run benchrunner -- 5 "CSort Insertionsort" Seq 100
155-
cabal run benchrunner -- 5 "CSort Mergesort" Seq 100
156-
cabal run benchrunner -- 5 "CSort Quicksort" Seq 100
148+
cabal run benchrunner -- 5 Insertionsort Seq 1000
149+
cabal run benchrunner -- 5 Quicksort Seq 1000
150+
cabal run benchrunner -- 5 Mergesort Seq 1000
151+
cabal run benchrunner -- 5 Mergesort Par 1000 +RTS -N1
152+
cabal run benchrunner -- 5 Mergesort Par 1000 +RTS -N2
153+
cabal run benchrunner -- 5 "VectorSort Insertionsort" Seq 1000
154+
cabal run benchrunner -- 5 "VectorSort Mergesort" Seq 1000
155+
cabal run benchrunner -- 5 "VectorSort Quicksort" Seq 1000
156+
cabal run benchrunner -- 5 "CSort Insertionsort" Seq 1000
157+
cabal run benchrunner -- 5 "CSort Mergesort" Seq 1000
158+
cabal run benchrunner -- 5 "CSort Quicksort" Seq 1000

src/Array.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ generate_loop arr idx end f =
179179
copy2_par :: HasPrim a => Int -> Int -> Int -> Array a -. Array a -. (Array a, Array a)
180180
copy2_par src_offset0 dst_offset0 len0 =
181181
Unsafe.toLinear (\src0 -> Unsafe.toLinear (\dst0 -> (src0, copy_par src0 src_offset0 dst0 dst_offset0 len0)))
182+
{-# INLINABLE copy2_par #-}
182183

183184
--TODO: src_offset0 and dst_offset0 are not respected.
184185
{- @ ignore copy_par @-}
@@ -205,6 +206,7 @@ copy_par src0 src_offset0 dst0 dst_offset0 len0 = copy_par' src0 src_offset0 dst
205206
#else
206207
copy_par src0 src_offset0 dst0 dst_offset0 len0 = copy src0 src_offset0 dst0 dst_offset0 len0
207208
#endif
209+
{-# INLINABLE copy_par #-}
208210

209211
--TODO: src_offset0 and dst_offset0 are not respected.
210212
{-@ ignore copy_par_m @-}
@@ -223,6 +225,7 @@ copy_par_m !src0 src_offset0 !dst0 dst_offset0 !len0 = copy_par_m' src0 src_offs
223225
!right <- copy_par_m' src_r 0 dst_r 0 (len-half)
224226
!left <- P.get left_f
225227
pure $ append left right
228+
{-# INLINABLE copy_par_m #-}
226229

227230
-- {-@ ignore foldl1_par @-}
228231
-- foldl1_par :: Int -> (a -> a -> a) -> a -> Array a -> a

src/Array/Mutable.hs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{-# LANGUAGE MagicHash #-}
33
{-# LANGUAGE BangPatterns #-}
44

5+
{-# OPTIONS_GHC -Wno-name-shadowing #-}
56

67
-- The Strict pragma is not just for performance, it's necessary for correctness.
78
-- Without it, this implementation contains a bug related to some thunk/effect
@@ -146,8 +147,8 @@ splitAt m = Unsafe.toLinear (\xs -> (slice xs 0 m, slice xs m (size xs)))
146147
append :: Array a -. Array a -. Array a
147148
append xs ys =
148149
let !res = Unsafe.toLinear (\xs -> case xs of
149-
(Array l1 _r1 !a1) -> Unsafe.toLinear (\ys -> case ys of
150-
(Array _l2 r2 _a2) -> Array l1 r2 a1)) xs ys
150+
(Array !l1 _r1 !a1) -> Unsafe.toLinear (\ys -> case ys of
151+
(Array _l2 !r2 _a2) -> Array l1 r2 a1)) xs ys
151152
in res
152153

153154
-- token xs == token ys

src/DpsMergePar.hs

Lines changed: 80 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
{-# LANGUAGE BangPatterns #-}
33
{-# LANGUAGE CPP #-}
44

5+
{-# OPTIONS_GHC -Wno-name-shadowing #-}
6+
57
module DpsMergePar where
68

79
import qualified Language.Haskell.Liquid.Bag as B
@@ -349,72 +351,89 @@ merge_par' :: (Show a, HasPrimOrd a, NFData a) =>
349351
merge_par' :: (Show a, HasPrimOrd a) =>
350352
#endif
351353
A.Array a -. (A.Array a -. (A.Array a -. ((A.Array a, A.Array a), A.Array a)))
352-
merge_par' !src1 !src2 !dst =
353-
let !(Ur n3, dst') = A.size2 dst in
354-
if n3 < goto_seqmerge
355-
then merge' 0 0 0 src1 src2 dst'
356-
? toProof (merge_par_func src1 src2 dst === merge_func src1 src2 dst 0 0 0)
357-
else let !(Ur n1, src1') = A.size2 src1
358-
!(Ur n2, src2') = A.size2 src2
359-
in if n1 == 0
360-
then let !(src2'1, dst'') = A.copy2_par 0 0 n2 src2' dst'
361-
in ((src1', src2'1), dst'')
362-
else if n2 == 0
363-
then let !(src1'1, dst'') = A.copy2_par 0 0 n1 src1' dst'
364-
in ((src1'1, src2'), dst'')
365-
else let mid1 = n1 `div` 2
366-
!(Ur pivot, src1'1) = A.get2 mid1 src1'
367-
!(Ur mid2, src2'1) = binarySearch pivot src2' -- src2[mid2] must <= all src1[mid1+1..]
368-
-- must >= all src1[0..mid1]
369-
!(src1_l, src1_cr) = A.splitAt mid1 src1'1
370-
!(src1_c, src1_r) = A.splitAt 1 src1_cr
371-
!(src2_l, src2_r) = A.splitAt mid2 src2'1
372-
373-
!(dst_l, dst_cr) = A.splitAt (mid1+mid2) dst'
374-
!(dst_c, dst_r) = A.splitAt 1 dst_cr
375-
!dst_c' = A.setLin 0 pivot dst_c
376-
377-
!(((src1_l',src2_l'), dst_l'), ((src1_r',src2_r'), dst_r'))
378-
= (merge_par' src1_l src2_l dst_l) .||. (merge_par' src1_r src2_r dst_r)
379-
{-
380-
(left, right) = tuple2 (merge_par' src1_l src2_l) dst_l
381-
-- ( ( (src1_l ? lem_isSortedBtw_slice src1'1 0 mid1)
382-
-- , (src2_l ? lem_isSortedBtw_slice src2'1 0 mid2) )
383-
-- , dst_l )
384-
(merge_par' src1_r src2_r) dst_r
385-
-- ( ( (src1_r ? lem_isSortedBtw_slice src1'1 mid1 n1
386-
-- ? lem_isSortedBtw_slice src1_cr 1 (n1-mid1))
387-
-- , (src2_r ? lem_isSortedBtw_slice src2'1 mid2 n2) )
388-
-- , dst_r )
389-
-}
390-
!src1_cr' = A.append src1_c src1_r'
391-
!src1'3 = A.append src1_l' src1_cr'
392-
!src2'3 = A.append src2_l' src2_r'
393-
!dst'' = A.append dst_l' dst_c'
394-
!dst''' = A.append dst'' dst_r'
395-
in ((src1'3, src2'3), dst''')
354+
merge_par' !src1 !src2 !dst = go src1 src2 dst where
355+
{-@ go :: xs1:(Array a) -> { xs2:(Array a) | token xs1 == token xs2 }
356+
-> { zs :(Array a) | size xs1 + size xs2 == size zs }
357+
-> { t:_ | snd t == merge_par_func xs1 xs2 zs &&
358+
token (fst (fst t)) == token xs1 && token (snd (fst t)) == token xs2 &&
359+
left (fst (fst t)) == left xs1 && right (fst (fst t)) == right xs1 &&
360+
left (snd (fst t)) == left xs2 && right (snd (fst t)) == right xs2 &&
361+
size (snd t) == size zs && token (snd t) == token zs &&
362+
left (snd t) == left zs && right (snd t) == right zs } / [size xs1] @-}
363+
#ifdef MUTABLE_ARRAYS
364+
go :: (Show a, HasPrimOrd a, NFData a) =>
365+
#else
366+
go :: (Show a, HasPrimOrd a) =>
367+
#endif
368+
A.Array a -. (A.Array a -. (A.Array a -. ((A.Array a, A.Array a), A.Array a)))
369+
go src1 src2 dst =
370+
let !(Ur n3, dst') = A.size2 dst in
371+
if n3 < goto_seqmerge
372+
then merge' 0 0 0 src1 src2 dst'
373+
? toProof (merge_par_func src1 src2 dst === merge_func src1 src2 dst 0 0 0)
374+
else let !(Ur n1, src1') = A.size2 src1
375+
!(Ur n2, src2') = A.size2 src2
376+
in if n1 == 0
377+
then let !(src2'1, dst'') = A.copy2_par 0 0 n2 src2' dst'
378+
in ((src1', src2'1), dst'')
379+
else if n2 == 0
380+
then let !(src1'1, dst'') = A.copy2_par 0 0 n1 src1' dst'
381+
in ((src1'1, src2'), dst'')
382+
else let mid1 = n1 `div` 2
383+
!(Ur pivot, src1'1) = A.get2 mid1 src1'
384+
!(Ur mid2, src2'1) = binarySearch pivot src2' -- src2[mid2] must <= all src1[mid1+1..]
385+
-- must >= all src1[0..mid1]
386+
!(src1_l, src1_cr) = A.splitAt mid1 src1'1
387+
!(src1_c, src1_r) = A.splitAt 1 src1_cr
388+
!(src2_l, src2_r) = A.splitAt mid2 src2'1
389+
390+
!(dst_l, dst_cr) = A.splitAt (mid1+mid2) dst'
391+
!(dst_c, dst_r) = A.splitAt 1 dst_cr
392+
!dst_c' = A.setLin 0 pivot dst_c
393+
394+
!(((src1_l',src2_l'), dst_l'), ((src1_r',src2_r'), dst_r'))
395+
= (go src1_l src2_l dst_l) .||. (go src1_r src2_r dst_r)
396+
{-
397+
(left, right) = tuple2 (merge_par' src1_l src2_l) dst_l
398+
-- ( ( (src1_l ? lem_isSortedBtw_slice src1'1 0 mid1)
399+
-- , (src2_l ? lem_isSortedBtw_slice src2'1 0 mid2) )
400+
-- , dst_l )
401+
(merge_par' src1_r src2_r) dst_r
402+
-- ( ( (src1_r ? lem_isSortedBtw_slice src1'1 mid1 n1
403+
-- ? lem_isSortedBtw_slice src1_cr 1 (n1-mid1))
404+
-- , (src2_r ? lem_isSortedBtw_slice src2'1 mid2 n2) )
405+
-- , dst_r )
406+
-}
407+
!src1_cr' = A.append src1_c src1_r'
408+
!src1'3 = A.append src1_l' src1_cr'
409+
!src2'3 = A.append src2_l' src2_r'
410+
!dst'' = A.append dst_l' dst_c'
411+
!dst''' = A.append dst'' dst_r'
412+
in ((src1'3, src2'3), dst''')
413+
{-# INLINE merge_par' #-}
396414

397415
{-@ binarySearch :: query:_ -> ls:_
398416
-> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
399417
snd tup == ls && (unur (fst tup), snd tup) = (binarySearch_func ls query, ls) } @-}
400418
binarySearch :: HasPrimOrd a => a -> A.Array a -. (Ur Int, A.Array a) -- must be able to return out of bounds
401419
binarySearch query ls = let !(Ur n, ls') = A.size2 ls
402-
in binarySearch' query 0 n ls'
403-
404-
{-@ binarySearch' :: query:_ -> lo:Nat
405-
-> { hi:Nat | lo <= hi }
406-
-> { ls:_ | hi <= size ls }
407-
-> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
408-
snd tup == ls &&
409-
(unur (fst tup), snd tup) = (binarySearch_func' ls query lo hi, ls) } / [hi-lo] @-}
410-
binarySearch' :: HasPrimOrd a => a -> Int -> Int -> A.Array a -. (Ur Int, A.Array a)
411-
binarySearch' query lo hi ls = if lo == hi
412-
then (Ur lo, ls)
413-
else let mid = lo + (hi - lo) `div` 2
414-
!(Ur midElt, ls') = A.get2 mid ls
415-
in if query < midElt
416-
then binarySearch' query lo mid ls'
417-
else binarySearch' query (mid+1) hi ls'
420+
in binarySearch' query 0 n ls' where
421+
422+
{-@ binarySearch' :: query:_ -> lo:Nat
423+
-> { hi:Nat | lo <= hi }
424+
-> { ls:_ | hi <= size ls }
425+
-> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
426+
snd tup == ls &&
427+
(unur (fst tup), snd tup) = (binarySearch_func' ls query lo hi, ls) } / [hi-lo] @-}
428+
binarySearch' :: HasPrimOrd a => a -> Int -> Int -> A.Array a -. (Ur Int, A.Array a)
429+
binarySearch' query lo hi ls = if lo == hi
430+
then (Ur lo, ls)
431+
else let mid = lo + (hi - lo) `div` 2
432+
!(Ur midElt, ls') = A.get2 mid ls
433+
in if query < midElt
434+
then binarySearch' query lo mid ls'
435+
else binarySearch' query (mid+1) hi ls'
436+
{-# INLINE binarySearch #-}
418437

419438
{-@ merge_par :: { xs1:(Array a) | isSorted' xs1 }
420439
-> { xs2:(Array a) | isSorted' xs2 && token xs1 == token xs2 && right xs1 == left xs2 }
@@ -426,9 +445,6 @@ binarySearch' query lo hi ls = if lo == hi
426445
left (snd t) == left zs && right (snd t) == right zs &&
427446
size (fst t) == size xs1 + size xs2 &&
428447
size (snd t) == size zs } / [size xs1] @-}
429-
{-# INLINE merge_par #-}
430-
{-# SPECIALISE merge_par :: A.Array Float -. A.Array Float -. A.Array Float -. (A.Array Float, A.Array Float) #-}
431-
{-# SPECIALISE merge_par :: A.Array Int -. A.Array Int -. A.Array Int -. (A.Array Int, A.Array Int) #-}
432448
#ifdef MUTABLE_ARRAYS
433449
merge_par :: (Show a, HasPrimOrd a, NFData a) =>
434450
#else
@@ -441,3 +457,4 @@ merge_par !src1 !src2 !dst =
441457
in (src', dst')
442458
? lem_merge_par_func_sorted src1 src2 dst
443459
? lem_merge_par_func_equiv src1 src2 dst
460+
{-# INLINABLE merge_par #-}

src/DpsMergeParSeqFallback.hs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
{-# LANGUAGE BangPatterns #-}
33
{-# LANGUAGE CPP #-}
44

5+
{-# OPTIONS_GHC -Wno-name-shadowing #-}
6+
57
module DpsMergeParSeqFallback where
68

79
import qualified Language.Haskell.Liquid.Bag as B
@@ -247,25 +249,41 @@ merge' :: HasPrimOrd a =>
247249
Int -> Int -> Int ->
248250
A.Array a -. A.Array a -. A.Array a -.
249251
((A.Array a, A.Array a), A.Array a)
250-
merge' i1 i2 j !src1 !src2 !dst =
251-
let !(Ur len1, src1') = A.size2 src1
252-
!(Ur len2, src2') = A.size2 src2 in
253-
if i1 >= len1
254-
then
255-
let !(src2'1, dst') = A.copy2_par i2 j (len2-i2) src2' dst in ((src1', src2'1), dst')
256-
else if i2 >= len2
257-
then
258-
let !(src1'1, dst') = A.copy2_par i1 j (len1-i1) src1' dst in ((src1'1, src2'), dst')
259-
else
260-
let !(Ur v1, src1'1) = A.get2 i1 src1'
261-
!(Ur v2, src2'1) = A.get2 i2 src2' in
262-
if v1 < v2
263-
then let dst' = A.setLin j v1 dst
264-
!(src_tup, dst'') = merge' (i1 + 1) i2 (j + 1) src1'1 src2'1 dst' in
265-
(src_tup, dst'')
266-
else let dst' = A.setLin j v2 dst
267-
!(src_tup, dst'') = merge' i1 (i2 + 1) (j + 1) src1'1 src2'1 dst' in
268-
(src_tup, dst'')
252+
merge' i1 i2 j !src1 !src2 !dst = go i1 i2 j src1 src2 dst where
253+
{-@ go :: i1:Nat -> i2:Nat -> { j:Nat | i1 + i2 == j }
254+
-> { xs1:(Array a) | i1 <= size xs1 }
255+
-> { xs2:(Array a) | token xs1 == token xs2 && i2 <= size xs2 }
256+
-> { zs:(Array a) | size xs1 + size xs2 == size zs && j <= size zs }
257+
-> { t:_ | t == ((xs1, xs2), merge_func xs1 xs2 zs i1 i2 j) &&
258+
token (fst (fst t)) == token xs1 && token (snd (fst t)) == token xs2 &&
259+
left (fst (fst t)) == left xs1 && right (fst (fst t)) == right xs1 &&
260+
left (snd (fst t)) == left xs2 && right (snd (fst t)) == right xs2 &&
261+
size (snd t) == size zs && token (snd t) == token zs &&
262+
left (snd t) == left zs && right (snd t) == right zs } / [size zs - j] @-}
263+
go :: HasPrimOrd a =>
264+
Int -> Int -> Int ->
265+
A.Array a -. A.Array a -. A.Array a -.
266+
((A.Array a, A.Array a), A.Array a)
267+
go !i1 !i2 !j src1 src2 dst =
268+
let !(Ur len1, !src1') = A.size2 src1
269+
!(Ur len2, !src2') = A.size2 src2 in
270+
if i1 >= len1
271+
then
272+
let !(src2'1, dst') = A.copy2_par i2 j (len2-i2) src2' dst in ((src1', src2'1), dst')
273+
else if i2 >= len2
274+
then
275+
let !(src1'1, dst') = A.copy2_par i1 j (len1-i1) src1' dst in ((src1'1, src2'), dst')
276+
else
277+
let !(Ur v1, !src1'1) = A.get2 i1 src1'
278+
!(Ur v2, !src2'1) = A.get2 i2 src2' in
279+
if v1 < v2
280+
then let !dst' = A.setLin j v1 dst
281+
!(src_tup, dst'') = go (i1 + 1) i2 (j + 1) src1'1 src2'1 dst' in
282+
(src_tup, dst'')
283+
else let !dst' = A.setLin j v2 dst
284+
!(src_tup, dst'') = go i1 (i2 + 1) (j + 1) src1'1 src2'1 dst' in
285+
(src_tup, dst'')
286+
{-# INLINE merge' #-}
269287

270288
{-@ merge :: { xs1:(Array a) | isSorted' xs1 }
271289
-> { xs2:(Array a) | isSorted' xs2 && token xs1 == token xs2 }
@@ -278,12 +296,8 @@ merge' i1 i2 j !src1 !src2 !dst =
278296
left (snd (fst t)) == left xs2 && right (snd (fst t)) == right xs2 &&
279297
left (snd t) == left zs && right (snd t) == right zs &&
280298
size (snd t) == size zs } @-}
281-
{-# INLINE merge #-}
282-
{-# SPECIALISE merge :: A.Array Float -. A.Array Float -. A.Array Float
283-
-. ((A.Array Float, A.Array Float), A.Array Float) #-}
284-
{-# SPECIALISE merge :: A.Array Int -. A.Array Int -. A.Array Int
285-
-. ((A.Array Int, A.Array Int), A.Array Int) #-}
286299
merge :: HasPrimOrd a => A.Array a -. A.Array a -. A.Array a -. ((A.Array a, A.Array a), A.Array a)
287300
merge src1 src2 dst = merge' 0 0 0 src1 src2 dst -- the 0's are relative to the current
288301
? lem_merge_func_sorted src1 src2 dst 0 0 0 -- slices, not absolute indices
289302
? lem_merge_func_equiv src1 src2 dst 0 0 0
303+
{-# INLINABLE merge #-}

0 commit comments

Comments
 (0)