diff --git a/accelerate-fft.cabal b/accelerate-fft.cabal index e445fcd..19ae2ae 100644 --- a/accelerate-fft.cabal +++ b/accelerate-fft.cabal @@ -85,6 +85,7 @@ library accelerate-llvm >= 1.3 , accelerate-llvm-ptx >= 1.3 , containers >= 0.5 + , exceptions >= 0.10 , hashable >= 1.0 , unordered-containers >= 0.2 , cuda >= 0.5 diff --git a/src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs b/src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs index a845b41..c4f3085 100644 --- a/src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs +++ b/src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternGuards #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} diff --git a/src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs b/src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs index ca6ea96..b9c2dfc 100644 --- a/src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs +++ b/src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs @@ -1,6 +1,8 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Math.FFT.LLVM.Native.Ix -- Copyright : [2017..2020] The Accelerate Team diff --git a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs index d6a0038..6d7594d 100644 --- a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs +++ b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs @@ -1,6 +1,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternGuards #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} @@ -92,14 +93,14 @@ fft' plans mode shR eR = aout <- allocateRemote aR sh stream <- asks ptxStream future <- new - liftPar $ - withArray eR ain stream $ \d_in -> do - withArray eR aout stream $ \d_out -> do - withPlan plans (sh,t) $ \h -> do - liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out) - -- - put future aout - return future + withPlan plans (sh,t) $ \h -> do + liftPar $ + withArray eR ain stream $ \d_in -> do + withArray eR aout stream $ \d_out -> do + liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out) + -- + put future aout + return future in case eR of NumericRfloat32 -> go (ArrayR shR (eltR @(Complex Float))) diff --git a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs index 90bce04..a06c39f 100644 --- a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs +++ b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs @@ -16,6 +16,10 @@ module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base where +import Control.Concurrent.MVar +import Control.Exception (evaluate) +import Control.Monad.Catch +import Control.Monad.IO.Class import Data.Array.Accelerate.Math.FFT.Type import Data.Array.Accelerate.Array.Data @@ -57,9 +61,17 @@ withArrayData NumericRfloat64 ad s k = return (Just e, r) {-# INLINE withLifetime' #-} -withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b +withLifetime' :: MonadIO m => Lifetime a -> (a -> m b) -> m b withLifetime' l k = do r <- k (unsafeGetValue l) liftIO $ touchLifetime l return r +{-# INLINE modifyMVar' #-} +modifyMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m (a,b)) -> m b +modifyMVar' m io = + mask $ \restore -> do + a <- liftIO (takeMVar m) + (a',b) <- restore (io a >>= liftIO . evaluate) `onException` liftIO (putMVar m a) + liftIO (putMVar m a') + return b diff --git a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs index bb4654c..5746519 100644 --- a/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs +++ b/src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs @@ -1,5 +1,7 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} -- | -- Module : Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans -- Copyright : [2017..2020] The Accelerate Team @@ -19,26 +21,32 @@ module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans ( ) where import Data.Array.Accelerate.Lifetime -import Data.Array.Accelerate.LLVM.PTX +import Data.Array.Accelerate.LLVM.PTX hiding (stream, poll) import Data.Array.Accelerate.LLVM.PTX.Foreign import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base import Control.Concurrent.MVar +import Control.Monad.Catch import Control.Monad.State -import Data.HashMap.Strict +import Data.HashMap.Strict hiding (map, update) import qualified Data.HashMap.Strict as Map import qualified Foreign.CUDA.Driver.Context as CUDA +import qualified Foreign.CUDA.Driver.Stream as CUDA import qualified Foreign.CUDA.FFT as FFT import GHC.Ptr import GHC.Base -import Prelude hiding ( lookup ) +import Prelude hiding ( lookup, mapM ) +import Data.Maybe +import Control.Arrow (second) +import Data.Function ((&)) +import Control.Monad.Reader (asks) data Plans a = Plans - { plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) (Lifetime FFT.Handle))) + { plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) [(Lifetime FFT.Handle, Maybe (Par PTX Bool, CUDA.Stream))])) , create :: a -> IO FFT.Handle , hash :: a -> Int } @@ -62,30 +70,57 @@ createPlan via mix = -- -- -- +-- TODO: Determine if this handle is used in the same stream. {-# INLINE withPlan #-} -withPlan :: Plans a -> a -> (FFT.Handle -> LLVM PTX b) -> LLVM PTX b +withPlan :: Plans a -> a -> (FFT.Handle -> Par PTX (Future b)) -> Par PTX (Future b) withPlan Plans{..} a k = do lc <- gets (deviceContext . ptxContext) - h <- liftIO $ - withLifetime lc $ \ctx -> - modifyMVar plans $ \pm -> - let key = (toKey ctx, hash a) in - case Map.lookup key pm of - -- handle does not exist yet; create it and add to the global - -- state for reuse - Nothing -> do - h <- create a - l <- newLifetime h - addFinalizer lc $ modifyMVar plans (\pm' -> return (Map.delete key pm', ())) - addFinalizer l $ FFT.destroy h - return ( Map.insert key l pm, l ) - - -- return existing handle - Just h -> return (pm, h) - -- - withLifetime' h k + ls <- asks ptxStream + withLifetime' ls $ \stream -> + withLifetime' lc $ \ctx -> do + let key = (toKey ctx, hash a) + -- Extract an existing cuFFT plan handle from our plan cache that isn't busy, + -- if one cannot be found, create a new cuFFT handle. + h <- modifyMVar' plans $ \pm -> do + let maybeHandles = pm !? key + handles = fromMaybe [] maybeHandles + + update Nothing = pure Nothing + update orig@(Just (isReady, _)) = isReady >>= \case + True -> pure Nothing + False -> pure orig + + updatedHandles <- zip (map fst handles) <$> mapM (update . snd) handles + + -- Extract first handle which is either entirely ready or is used but within the same stream + let extractFirstReady [] = (Nothing, []) + extractFirstReady (x@(_, Nothing):xs) = (Just x, xs) + extractFirstReady (x@(_, Just (_, s)):xs) | stream == s = (Just x, xs) + extractFirstReady (x@(_, Just _):xs) = second (x:) $ extractFirstReady xs + + (maybeReadyHandle, otherHandles) = extractFirstReady updatedHandles + + newHandle = liftIO $ do + h <- create a + l <- newLifetime h + addFinalizer l $ FFT.destroy h + when (isNothing maybeHandles) $ + addFinalizer lc $ modifyMVar_ plans $ pure . Map.delete key + pure l + + maybeReadyHandle & maybe newHandle (pure . fst) + & fmap (Map.insert key otherHandles pm,) + -- Ensure the handle is always returned back to the plan cache + let returnHandle = liftIO $ modifyMVar_ plans $ pure . Map.adjust ((h, Nothing):) key + flip onException returnHandle $ do + -- Invoke user-provided function with cuFFT handle + future <- withLifetime' h k + -- Push new cuFFT plan-handle onto list of plan-handles of equal settings, + -- w/ callback to check if the cuFFT handle is ready to use again. + planHandleEntry <- (h,) . Just . (,stream) . fmap isJust . poll <$> statusHandle future + liftIO $ modifyMVar_ plans $ pure . Map.adjust (planHandleEntry:) key + pure future {-# INLINE toKey #-} toKey :: CUDA.Context -> Int toKey (CUDA.Context (Ptr addr#)) = I# (addr2Int# addr#) - diff --git a/src/Data/Array/Accelerate/Math/FFT/Type.hs b/src/Data/Array/Accelerate/Math/FFT/Type.hs index 11415cd..dbaeb6a 100644 --- a/src/Data/Array/Accelerate/Math/FFT/Type.hs +++ b/src/Data/Array/Accelerate/Math/FFT/Type.hs @@ -2,6 +2,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE TypeOperators #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Math.FFT.Type