22{-# LANGUAGE BangPatterns #-}
33{-# LANGUAGE CPP #-}
44
5+ {-# OPTIONS_GHC -Wno-name-shadowing #-}
6+
57module DpsMergePar where
68
79import qualified Language.Haskell.Liquid.Bag as B
@@ -349,72 +351,89 @@ merge_par' :: (Show a, HasPrimOrd a, NFData a) =>
349351merge_par' :: (Show a , HasPrimOrd a ) =>
350352# endif
351353 A. Array a -. (A. Array a -. (A. Array a -. ((A. Array a, A. Array a), A. Array a)))
352- merge_par' ! src1 ! src2 ! dst =
353- let ! (Ur n3, dst') = A. size2 dst in
354- if n3 < goto_seqmerge
355- then merge' 0 0 0 src1 src2 dst'
356- ? toProof (merge_par_func src1 src2 dst === merge_func src1 src2 dst 0 0 0 )
357- else let ! (Ur n1, src1') = A. size2 src1
358- ! (Ur n2, src2') = A. size2 src2
359- in if n1 == 0
360- then let ! (src2'1, dst'') = A. copy2_par 0 0 n2 src2' dst'
361- in ((src1', src2'1), dst'')
362- else if n2 == 0
363- then let ! (src1'1, dst'') = A. copy2_par 0 0 n1 src1' dst'
364- in ((src1'1, src2'), dst'')
365- else let mid1 = n1 `div` 2
366- ! (Ur pivot, src1'1) = A. get2 mid1 src1'
367- ! (Ur mid2, src2'1) = binarySearch pivot src2' -- src2[mid2] must <= all src1[mid1+1..]
368- -- must >= all src1[0..mid1]
369- ! (src1_l, src1_cr) = A. splitAt mid1 src1'1
370- ! (src1_c, src1_r) = A. splitAt 1 src1_cr
371- ! (src2_l, src2_r) = A. splitAt mid2 src2'1
372-
373- ! (dst_l, dst_cr) = A. splitAt (mid1+ mid2) dst'
374- ! (dst_c, dst_r) = A. splitAt 1 dst_cr
375- ! dst_c' = A. setLin 0 pivot dst_c
376-
377- ! (((src1_l',src2_l'), dst_l'), ((src1_r',src2_r'), dst_r'))
378- = (merge_par' src1_l src2_l dst_l) .||. (merge_par' src1_r src2_r dst_r)
379- {-
380- (left, right) = tuple2 (merge_par' src1_l src2_l) dst_l
381- -- ( ( (src1_l ? lem_isSortedBtw_slice src1'1 0 mid1)
382- -- , (src2_l ? lem_isSortedBtw_slice src2'1 0 mid2) )
383- -- , dst_l )
384- (merge_par' src1_r src2_r) dst_r
385- -- ( ( (src1_r ? lem_isSortedBtw_slice src1'1 mid1 n1
386- -- ? lem_isSortedBtw_slice src1_cr 1 (n1-mid1))
387- -- , (src2_r ? lem_isSortedBtw_slice src2'1 mid2 n2) )
388- -- , dst_r )
389- -}
390- ! src1_cr' = A. append src1_c src1_r'
391- ! src1'3 = A. append src1_l' src1_cr'
392- ! src2'3 = A. append src2_l' src2_r'
393- ! dst'' = A. append dst_l' dst_c'
394- ! dst''' = A. append dst'' dst_r'
395- in ((src1'3, src2'3), dst''')
354+ merge_par' ! src1 ! src2 ! dst = go src1 src2 dst where
355+ {- @ go :: xs1:(Array a) -> { xs2:(Array a) | token xs1 == token xs2 }
356+ -> { zs :(Array a) | size xs1 + size xs2 == size zs }
357+ -> { t:_ | snd t == merge_par_func xs1 xs2 zs &&
358+ token (fst (fst t)) == token xs1 && token (snd (fst t)) == token xs2 &&
359+ left (fst (fst t)) == left xs1 && right (fst (fst t)) == right xs1 &&
360+ left (snd (fst t)) == left xs2 && right (snd (fst t)) == right xs2 &&
361+ size (snd t) == size zs && token (snd t) == token zs &&
362+ left (snd t) == left zs && right (snd t) == right zs } / [size xs1] @-}
363+ # ifdef MUTABLE_ARRAYS
364+ go :: (Show a , HasPrimOrd a , NFData a ) =>
365+ # else
366+ go :: (Show a , HasPrimOrd a ) =>
367+ # endif
368+ A. Array a -. (A. Array a -. (A. Array a -. ((A. Array a, A. Array a), A. Array a)))
369+ go src1 src2 dst =
370+ let ! (Ur n3, dst') = A. size2 dst in
371+ if n3 < goto_seqmerge
372+ then merge' 0 0 0 src1 src2 dst'
373+ ? toProof (merge_par_func src1 src2 dst === merge_func src1 src2 dst 0 0 0 )
374+ else let ! (Ur n1, src1') = A. size2 src1
375+ ! (Ur n2, src2') = A. size2 src2
376+ in if n1 == 0
377+ then let ! (src2'1, dst'') = A. copy2_par 0 0 n2 src2' dst'
378+ in ((src1', src2'1), dst'')
379+ else if n2 == 0
380+ then let ! (src1'1, dst'') = A. copy2_par 0 0 n1 src1' dst'
381+ in ((src1'1, src2'), dst'')
382+ else let mid1 = n1 `div` 2
383+ ! (Ur pivot, src1'1) = A. get2 mid1 src1'
384+ ! (Ur mid2, src2'1) = binarySearch pivot src2' -- src2[mid2] must <= all src1[mid1+1..]
385+ -- must >= all src1[0..mid1]
386+ ! (src1_l, src1_cr) = A. splitAt mid1 src1'1
387+ ! (src1_c, src1_r) = A. splitAt 1 src1_cr
388+ ! (src2_l, src2_r) = A. splitAt mid2 src2'1
389+
390+ ! (dst_l, dst_cr) = A. splitAt (mid1+ mid2) dst'
391+ ! (dst_c, dst_r) = A. splitAt 1 dst_cr
392+ ! dst_c' = A. setLin 0 pivot dst_c
393+
394+ ! (((src1_l',src2_l'), dst_l'), ((src1_r',src2_r'), dst_r'))
395+ = (go src1_l src2_l dst_l) .||. (go src1_r src2_r dst_r)
396+ {-
397+ (left, right) = tuple2 (merge_par' src1_l src2_l) dst_l
398+ -- ( ( (src1_l ? lem_isSortedBtw_slice src1'1 0 mid1)
399+ -- , (src2_l ? lem_isSortedBtw_slice src2'1 0 mid2) )
400+ -- , dst_l )
401+ (merge_par' src1_r src2_r) dst_r
402+ -- ( ( (src1_r ? lem_isSortedBtw_slice src1'1 mid1 n1
403+ -- ? lem_isSortedBtw_slice src1_cr 1 (n1-mid1))
404+ -- , (src2_r ? lem_isSortedBtw_slice src2'1 mid2 n2) )
405+ -- , dst_r )
406+ -}
407+ ! src1_cr' = A. append src1_c src1_r'
408+ ! src1'3 = A. append src1_l' src1_cr'
409+ ! src2'3 = A. append src2_l' src2_r'
410+ ! dst'' = A. append dst_l' dst_c'
411+ ! dst''' = A. append dst'' dst_r'
412+ in ((src1'3, src2'3), dst''')
413+ {-# INLINE merge_par' #-}
396414
397415{- @ binarySearch :: query:_ -> ls:_
398416 -> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
399417 snd tup == ls && (unur (fst tup), snd tup) = (binarySearch_func ls query, ls) } @-}
400418binarySearch :: HasPrimOrd a => a -> A. Array a -. (Ur Int , A. Array a ) -- must be able to return out of bounds
401419binarySearch query ls = let ! (Ur n, ls') = A. size2 ls
402- in binarySearch' query 0 n ls'
403-
404- {- @ binarySearch' :: query:_ -> lo:Nat
405- -> { hi:Nat | lo <= hi }
406- -> { ls:_ | hi <= size ls }
407- -> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
408- snd tup == ls &&
409- (unur (fst tup), snd tup) = (binarySearch_func' ls query lo hi, ls) } / [hi-lo] @-}
410- binarySearch' :: HasPrimOrd a => a -> Int -> Int -> A. Array a -. (Ur Int , A. Array a )
411- binarySearch' query lo hi ls = if lo == hi
412- then (Ur lo, ls)
413- else let mid = lo + (hi - lo) `div` 2
414- ! (Ur midElt, ls') = A. get2 mid ls
415- in if query < midElt
416- then binarySearch' query lo mid ls'
417- else binarySearch' query (mid+ 1 ) hi ls'
420+ in binarySearch' query 0 n ls' where
421+
422+ {- @ binarySearch' :: query:_ -> lo:Nat
423+ -> { hi:Nat | lo <= hi }
424+ -> { ls:_ | hi <= size ls }
425+ -> { tup:_ | 0 <= unur (fst tup) && unur (fst tup) <= size ls &&
426+ snd tup == ls &&
427+ (unur (fst tup), snd tup) = (binarySearch_func' ls query lo hi, ls) } / [hi-lo] @-}
428+ binarySearch' :: HasPrimOrd a => a -> Int -> Int -> A. Array a -. (Ur Int , A. Array a )
429+ binarySearch' query lo hi ls = if lo == hi
430+ then (Ur lo, ls)
431+ else let mid = lo + (hi - lo) `div` 2
432+ ! (Ur midElt, ls') = A. get2 mid ls
433+ in if query < midElt
434+ then binarySearch' query lo mid ls'
435+ else binarySearch' query (mid+ 1 ) hi ls'
436+ {-# INLINE binarySearch #-}
418437
419438{- @ merge_par :: { xs1:(Array a) | isSorted' xs1 }
420439 -> { xs2:(Array a) | isSorted' xs2 && token xs1 == token xs2 && right xs1 == left xs2 }
@@ -426,9 +445,6 @@ binarySearch' query lo hi ls = if lo == hi
426445 left (snd t) == left zs && right (snd t) == right zs &&
427446 size (fst t) == size xs1 + size xs2 &&
428447 size (snd t) == size zs } / [size xs1] @-}
429- {-# INLINE merge_par #-}
430- {-# SPECIALISE merge_par :: A.Array Float -. A.Array Float -. A.Array Float -. (A.Array Float, A.Array Float) #-}
431- {-# SPECIALISE merge_par :: A.Array Int -. A.Array Int -. A.Array Int -. (A.Array Int, A.Array Int) #-}
432448# ifdef MUTABLE_ARRAYS
433449merge_par :: (Show a , HasPrimOrd a , NFData a ) =>
434450# else
@@ -441,3 +457,4 @@ merge_par !src1 !src2 !dst =
441457 in (src', dst')
442458 ? lem_merge_par_func_sorted src1 src2 dst
443459 ? lem_merge_par_func_equiv src1 src2 dst
460+ {-# INLINABLE merge_par #-}
0 commit comments