Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/Array.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LiberalTypeSynonyms #-}

-- {-# LANGUAGE Strict #-}

Expand All @@ -15,6 +16,9 @@ module Array

-- * Construction and querying
, alloc, make, generate, generate_par, generate_par_m, makeArray
, flattenCallback, makeCallback, biJoinAllocAffine, allocScratchAffine
, biJoinAlloc, allocScratch

, copy, copy_par, copy_par_m
, size, get, set, slice, append
, splitAt
Expand Down Expand Up @@ -95,9 +99,47 @@ makeArray = make
#endif

{-# INLINE free #-}
free :: HasPrim a => Array a -. ()
free :: Array a -. ()
free = Unsafe.toLinear (\_ -> ())

{-# INLINE flattenCallback #-}
flattenCallback :: (forall c. (Array b -. Ur c) -. Array a -. Ur c) -. Array a -. Array b
flattenCallback f arr = unur (f ur arr)

{-# INLINE makeCallback #-}
makeCallback :: (Array b -. Array a) -. (Array a -. Ur c) -. Array b -. Ur c
makeCallback direct k arr = k (direct arr)

{-# INLINE biJoinAllocAffine #-}
biJoinAllocAffine :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. Array dsts) -> Array srcs -. Array dsts
biJoinAllocAffine i a f = flattenCallback (\cont src -> alloc i a (\tmp -> makeCallback (f tmp) cont src))

-- efficient implementation of above
{-# INLINE allocScratchAffine #-}
allocScratchAffine :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. Array dsts) -> Array srcs -. Array dsts
allocScratchAffine i a f arr = f arr (makeArray i a)

{-# INLINE biJoinAlloc #-}
biJoinAlloc :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. (Array dsts, Array tmpdsts)) -> Array srcs -. Array dsts
biJoinAlloc i a f =
let
g tmp src =
let
!(dst, tmp') = f tmp src
in
case free tmp' of !() -> dst
in
flattenCallback (\cont src -> alloc i a (\tmp -> makeCallback (g tmp) cont src))

-- efficient implementation of above
{-# INLINE allocScratch #-}
allocScratch :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. (Array dsts, Array tmpdsts)) -> Array srcs -. Array dsts
allocScratch i a f arr =
let
!(dst, tmp) = f arr (makeArray i a)
in case free tmp of !() -> dst


--------------------------------------------------------------------------------
-- Parallel operations
--------------------------------------------------------------------------------
Expand Down
9 changes: 6 additions & 3 deletions src/DpsMergeSort4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@ msortInplace src tmp = go src tmp where
msort' :: (Show a, HasPrimOrd a) => a -> A.Array a -. A.Array a
msort' anyVal src =
let !(Ur len, src') = A.size2 src
!(src'', _tmp) = msortInplace src' (A.make len anyVal) in
case A.free _tmp of !() -> src''
-- The old implementation with unsafe operation make. In case we want to look into performance.
-- !(src'', _tmp) = msortInplace src' (A.make len anyVal) in
-- case A.free _tmp of !() -> src''
!src'' = A.allocScratch len anyVal msortInplace src' in
src''
{-# INLINE msort' #-}

-- finally, the top-level merge sort function -- TODO: use A.get2/A.size2 for linearity
-- finally, the top-level merge sort function
{-@ msort :: { xs:(A.Array a) | left xs == 0 && right xs == size xs }
-> { ys:_ | toBag xs == toBag ys && isSorted' ys &&
A.size xs == A.size ys && token xs == token ys } @-}
Expand Down
Loading