From f87e324d6fefd583cbd31cd9111b523cb3743b8c Mon Sep 17 00:00:00 2001 From: jazullo Date: Mon, 24 Mar 2025 06:15:33 +0000 Subject: [PATCH] Use allocation proposal in mergesort --- src/Array.hs | 44 +++++++++++++++++++++++++++++++++++++++++++- src/DpsMergeSort4.hs | 9 ++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/Array.hs b/src/Array.hs index 562c4e6..f69350b 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE LiberalTypeSynonyms #-} -- {-# LANGUAGE Strict #-} @@ -15,6 +16,9 @@ module Array -- * Construction and querying , alloc, make, generate, generate_par, generate_par_m, makeArray + , flattenCallback, makeCallback, biJoinAllocAffine, allocScratchAffine + , biJoinAlloc, allocScratch + , copy, copy_par, copy_par_m , size, get, set, slice, append , splitAt @@ -95,9 +99,47 @@ makeArray = make #endif {-# INLINE free #-} -free :: HasPrim a => Array a -. () +free :: Array a -. () free = Unsafe.toLinear (\_ -> ()) +{-# INLINE flattenCallback #-} +flattenCallback :: (forall c. (Array b -. Ur c) -. Array a -. Ur c) -. Array a -. Array b +flattenCallback f arr = unur (f ur arr) + +{-# INLINE makeCallback #-} +makeCallback :: (Array b -. Array a) -. (Array a -. Ur c) -. Array b -. Ur c +makeCallback direct k arr = k (direct arr) + +{-# INLINE biJoinAllocAffine #-} +biJoinAllocAffine :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. Array dsts) -> Array srcs -. Array dsts +biJoinAllocAffine i a f = flattenCallback (\cont src -> alloc i a (\tmp -> makeCallback (f tmp) cont src)) + +-- efficient implementation of above +{-# INLINE allocScratchAffine #-} +allocScratchAffine :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. Array dsts) -> Array srcs -. Array dsts +allocScratchAffine i a f arr = f arr (makeArray i a) + +{-# INLINE biJoinAlloc #-} +biJoinAlloc :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. (Array dsts, Array tmpdsts)) -> Array srcs -. Array dsts +biJoinAlloc i a f = + let + g tmp src = + let + !(dst, tmp') = f tmp src + in + case free tmp' of !() -> dst + in + flattenCallback (\cont src -> alloc i a (\tmp -> makeCallback (g tmp) cont src)) + +-- efficient implementation of above +{-# INLINE allocScratch #-} +allocScratch :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. (Array dsts, Array tmpdsts)) -> Array srcs -. Array dsts +allocScratch i a f arr = + let + !(dst, tmp) = f arr (makeArray i a) + in case free tmp of !() -> dst + + -------------------------------------------------------------------------------- -- Parallel operations -------------------------------------------------------------------------------- diff --git a/src/DpsMergeSort4.hs b/src/DpsMergeSort4.hs index 129d49f..375f6ff 100644 --- a/src/DpsMergeSort4.hs +++ b/src/DpsMergeSort4.hs @@ -77,11 +77,14 @@ msortInplace src tmp = go src tmp where msort' :: (Show a, HasPrimOrd a) => a -> A.Array a -. A.Array a msort' anyVal src = let !(Ur len, src') = A.size2 src - !(src'', _tmp) = msortInplace src' (A.make len anyVal) in - case A.free _tmp of !() -> src'' +-- The old implementation with unsafe operation make. In case we want to look into performance. +-- !(src'', _tmp) = msortInplace src' (A.make len anyVal) in +-- case A.free _tmp of !() -> src'' + !src'' = A.allocScratch len anyVal msortInplace src' in + src'' {-# INLINE msort' #-} --- finally, the top-level merge sort function -- TODO: use A.get2/A.size2 for linearity +-- finally, the top-level merge sort function {-@ msort :: { xs:(A.Array a) | left xs == 0 && right xs == size xs } -> { ys:_ | toBag xs == toBag ys && isSorted' ys && A.size xs == A.size ys && token xs == token ys } @-}