Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions accelerate-fft.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ library
accelerate-llvm >= 1.3
, accelerate-llvm-ptx >= 1.3
, containers >= 0.5
, exceptions >= 0.10
, hashable >= 1.0
, unordered-containers >= 0.2
, cuda >= 0.5
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
Expand Down
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- |
-- Module : Data.Array.Accelerate.Math.FFT.LLVM.Native.Ix
-- Copyright : [2017..2020] The Accelerate Team
Expand Down
17 changes: 9 additions & 8 deletions src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
Expand Down Expand Up @@ -92,14 +93,14 @@ fft' plans mode shR eR =
aout <- allocateRemote aR sh
stream <- asks ptxStream
future <- new
liftPar $
withArray eR ain stream $ \d_in -> do
withArray eR aout stream $ \d_out -> do
withPlan plans (sh,t) $ \h -> do
liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out)
--
put future aout
return future
withPlan plans (sh,t) $ \h -> do
liftPar $
withArray eR ain stream $ \d_in -> do
withArray eR aout stream $ \d_out -> do
liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out)
--
put future aout
return future
in
case eR of
NumericRfloat32 -> go (ArrayR shR (eltR @(Complex Float)))
Expand Down
14 changes: 13 additions & 1 deletion src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
where

import Control.Concurrent.MVar
import Control.Exception (evaluate)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Array.Accelerate.Math.FFT.Type

import Data.Array.Accelerate.Array.Data
Expand Down Expand Up @@ -57,9 +61,17 @@ withArrayData NumericRfloat64 ad s k =
return (Just e, r)

{-# INLINE withLifetime' #-}
withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b
withLifetime' :: MonadIO m => Lifetime a -> (a -> m b) -> m b
withLifetime' l k = do
r <- k (unsafeGetValue l)
liftIO $ touchLifetime l
return r

{-# INLINE modifyMVar' #-}
modifyMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m (a,b)) -> m b
modifyMVar' m io =
mask $ \restore -> do
a <- liftIO (takeMVar m)
(a',b) <- restore (io a >>= liftIO . evaluate) `onException` liftIO (putMVar m a)
liftIO (putMVar m a')
return b
83 changes: 59 additions & 24 deletions src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
-- |
-- Module : Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans
-- Copyright : [2017..2020] The Accelerate Team
Expand All @@ -19,26 +21,32 @@ module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans (
) where

import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX hiding (stream, poll)
import Data.Array.Accelerate.LLVM.PTX.Foreign

import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base

import Control.Concurrent.MVar
import Control.Monad.Catch
import Control.Monad.State
import Data.HashMap.Strict
import Data.HashMap.Strict hiding (map, update)
import qualified Data.HashMap.Strict as Map

import qualified Foreign.CUDA.Driver.Context as CUDA
import qualified Foreign.CUDA.Driver.Stream as CUDA
import qualified Foreign.CUDA.FFT as FFT

import GHC.Ptr
import GHC.Base
import Prelude hiding ( lookup )
import Prelude hiding ( lookup, mapM )
import Data.Maybe
import Control.Arrow (second)
import Data.Function ((&))
import Control.Monad.Reader (asks)


data Plans a = Plans
{ plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) (Lifetime FFT.Handle)))
{ plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) [(Lifetime FFT.Handle, Maybe (Par PTX Bool, CUDA.Stream))]))
, create :: a -> IO FFT.Handle
, hash :: a -> Int
}
Expand All @@ -62,30 +70,57 @@ createPlan via mix =
--
-- <http://docs.nvidia.com/cuda/cufft/index.html#thread-safety>
--
-- TODO: Determine if this handle is used in the same stream.
{-# INLINE withPlan #-}
withPlan :: Plans a -> a -> (FFT.Handle -> LLVM PTX b) -> LLVM PTX b
withPlan :: Plans a -> a -> (FFT.Handle -> Par PTX (Future b)) -> Par PTX (Future b)
withPlan Plans{..} a k = do
lc <- gets (deviceContext . ptxContext)
h <- liftIO $
withLifetime lc $ \ctx ->
modifyMVar plans $ \pm ->
let key = (toKey ctx, hash a) in
case Map.lookup key pm of
-- handle does not exist yet; create it and add to the global
-- state for reuse
Nothing -> do
h <- create a
l <- newLifetime h
addFinalizer lc $ modifyMVar plans (\pm' -> return (Map.delete key pm', ()))
addFinalizer l $ FFT.destroy h
return ( Map.insert key l pm, l )

-- return existing handle
Just h -> return (pm, h)
--
withLifetime' h k
ls <- asks ptxStream
withLifetime' ls $ \stream ->
withLifetime' lc $ \ctx -> do
let key = (toKey ctx, hash a)
-- Extract an existing cuFFT plan handle from our plan cache that isn't busy,
-- if one cannot be found, create a new cuFFT handle.
h <- modifyMVar' plans $ \pm -> do
let maybeHandles = pm !? key
handles = fromMaybe [] maybeHandles

update Nothing = pure Nothing
update orig@(Just (isReady, _)) = isReady >>= \case
True -> pure Nothing
False -> pure orig

updatedHandles <- zip (map fst handles) <$> mapM (update . snd) handles

-- Extract first handle which is either entirely ready or is used but within the same stream
let extractFirstReady [] = (Nothing, [])
extractFirstReady (x@(_, Nothing):xs) = (Just x, xs)
extractFirstReady (x@(_, Just (_, s)):xs) | stream == s = (Just x, xs)
extractFirstReady (x@(_, Just _):xs) = second (x:) $ extractFirstReady xs

(maybeReadyHandle, otherHandles) = extractFirstReady updatedHandles

newHandle = liftIO $ do
h <- create a
l <- newLifetime h
addFinalizer l $ FFT.destroy h
when (isNothing maybeHandles) $
addFinalizer lc $ modifyMVar_ plans $ pure . Map.delete key
pure l

maybeReadyHandle & maybe newHandle (pure . fst)
& fmap (Map.insert key otherHandles pm,)
-- Ensure the handle is always returned back to the plan cache
let returnHandle = liftIO $ modifyMVar_ plans $ pure . Map.adjust ((h, Nothing):) key
flip onException returnHandle $ do
-- Invoke user-provided function with cuFFT handle
future <- withLifetime' h k
-- Push new cuFFT plan-handle onto list of plan-handles of equal settings,
-- w/ callback to check if the cuFFT handle is ready to use again.
planHandleEntry <- (h,) . Just . (,stream) . fmap isJust . poll <$> statusHandle future
liftIO $ modifyMVar_ plans $ pure . Map.adjust (planHandleEntry:) key
pure future

{-# INLINE toKey #-}
toKey :: CUDA.Context -> Int
toKey (CUDA.Context (Ptr addr#)) = I# (addr2Int# addr#)

1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Math/FFT/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Math.FFT.Type
Expand Down
Loading