11{-# LANGUAGE CPP #-}
22{-# LANGUAGE BangPatterns #-}
33{-# LANGUAGE DeriveFunctor #-}
4+ {-# LANGUAGE LiberalTypeSynonyms #-}
45
56-- {-# LANGUAGE Strict #-}
67
@@ -15,6 +16,9 @@ module Array
1516
1617 -- * Construction and querying
1718 , alloc, make, generate, generate_par, generate_par_m, makeArray
19+ , flattenCallback, makeCallback, biJoinAllocAffine, allocScratchAffine
20+ , biJoinAlloc, allocScratch
21+
1822 , copy, copy_par, copy_par_m
1923 , size, get, set, slice, append
2024 , splitAt
@@ -95,9 +99,47 @@ makeArray = make
9599# endif
96100
97101{-# INLINE free #-}
98- free :: HasPrim a => Array a -. ()
102+ free :: Array a -. ()
99103free = Unsafe. toLinear (\ _ -> () )
100104
105+ {-# INLINE flattenCallback #-}
106+ flattenCallback :: (forall c . (Array b -. Ur c ) -. Array a -. Ur c ) -. Array a -. Array b
107+ flattenCallback f arr = unur (f ur arr)
108+
109+ {-# INLINE makeCallback #-}
110+ makeCallback :: (Array b -. Array a ) -. (Array a -. Ur c ) -. Array b -. Ur c
111+ makeCallback direct k arr = k (direct arr)
112+
113+ {-# INLINE biJoinAllocAffine #-}
114+ biJoinAllocAffine :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. Array dsts ) -> Array srcs -. Array dsts
115+ biJoinAllocAffine i a f = flattenCallback (\ cont src -> alloc i a (\ tmp -> makeCallback (f tmp) cont src))
116+
117+ -- efficient implementation of above
118+ {-# INLINE allocScratchAffine #-}
119+ allocScratchAffine :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. Array dsts ) -> Array srcs -. Array dsts
120+ allocScratchAffine i a f arr = f arr (makeArray i a)
121+
122+ {-# INLINE biJoinAlloc #-}
123+ biJoinAlloc :: HasPrim tmps => Int -> tmps -> (Array tmps -. Array srcs -. (Array dsts , Array tmpdsts )) -> Array srcs -. Array dsts
124+ biJoinAlloc i a f =
125+ let
126+ g tmp src =
127+ let
128+ ! (dst , tmp' ) = f tmp src
129+ in
130+ case free tmp' of ! () -> dst
131+ in
132+ flattenCallback (\ cont src -> alloc i a (\ tmp -> makeCallback (g tmp) cont src))
133+
134+ -- efficient implementation of above
135+ {-# INLINE allocScratch #-}
136+ allocScratch :: HasPrim tmps => Int -> tmps -> (Array srcs -. Array tmps -. (Array dsts , Array tmpdsts )) -> Array srcs -. Array dsts
137+ allocScratch i a f arr =
138+ let
139+ ! (dst , tmp ) = f arr (makeArray i a )
140+ in case free tmp of ! () -> dst
141+
142+
101143--------------------------------------------------------------------------------
102144-- Parallel operations
103145--------------------------------------------------------------------------------
0 commit comments