Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
/docs/_build
*.hi
*.o

hie.yaml
3 changes: 3 additions & 0 deletions src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,15 @@ module Data.Array.Accelerate (

-- ** SIMD vectors
Vec, VecElt,
mkVec,

-- ** Type classes
-- *** Basic type classes
Eq(..),
Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_,
Enum, succ, pred,
Bounded, minBound, maxBound,
Vectoring(..),
-- Functor(..), (<$>), ($>), void,
-- Monad(..),

Expand Down Expand Up @@ -445,6 +447,7 @@ import Data.Array.Accelerate.Classes.Rational
import Data.Array.Accelerate.Classes.RealFloat
import Data.Array.Accelerate.Classes.RealFrac
import Data.Array.Accelerate.Classes.ToFloating
import Data.Array.Accelerate.Classes.Vector
import Data.Array.Accelerate.Data.Either
import Data.Array.Accelerate.Data.Maybe
import Data.Array.Accelerate.Language
Expand Down
12 changes: 12 additions & 0 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,9 @@ data PrimFun sig where
PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
PrimLNot :: PrimFun (PrimBool -> PrimBool)

-- local array operators
PrimVectorIndex :: KnownNat n => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b)
Expand Down Expand Up @@ -924,6 +927,12 @@ primFunType = \case
PrimLOr -> binary' tbool
PrimLNot -> unary' tbool

-- Local Vector operations
PrimVectorIndex v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` i, single a)

-- general conversion between types
PrimFromIntegral a b -> unary (integral a) (num b)
PrimToFloating a b -> unary (num a) (floating b)
Expand All @@ -936,6 +945,7 @@ primFunType = \case
compare' a = binary (single a) tbool

single = TupRsingle . SingleScalarType
singleVector = TupRsingle . VectorScalarType
num = TupRsingle . SingleScalarType . NumSingleType
integral = num . IntegralNumType
floating = num . FloatingNumType
Expand Down Expand Up @@ -1165,6 +1175,7 @@ rnfPrimFun (PrimMin t) = rnfSingleType t
rnfPrimFun PrimLAnd = ()
rnfPrimFun PrimLOr = ()
rnfPrimFun PrimLNot = ()
rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i
rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n
rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f

Expand Down Expand Up @@ -1391,6 +1402,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||]
liftPrimFun PrimLAnd = [|| PrimLAnd ||]
liftPrimFun PrimLOr = [|| PrimLOr ||]
liftPrimFun PrimLNot = [|| PrimLNot ||]
liftPrimFun (PrimVectorIndex v i) = [||PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||]
liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||]
liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||]

Expand Down
3 changes: 1 addition & 2 deletions src/Data/Array/Accelerate/Classes/Enum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum"
preludeError :: String -> a
preludeError x
= error
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x
, ""
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , ""
, "These Prelude.Enum instances are present only to fulfil superclass"
, "constraints for subsequent classes in the standard Haskell numeric hierarchy."
]
Expand Down
31 changes: 31 additions & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module : Data.Array.Accelerate.Classes.Vector
-- Copyright : [2016..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Classes.Vector where

import GHC.TypeLits
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Smart
import Data.Primitive.Vec

class Vectoring a b c | a -> b where
indexAt :: a -> c -> b

instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) (Exp Int) where
indexAt = mkVectorIndex


12 changes: 12 additions & 0 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Smart
Expand Down Expand Up @@ -71,6 +72,9 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for type coercion functions
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),

-- ** Smart constructors for vector operations
mkVectorIndex,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
Expand All @@ -83,6 +87,7 @@ module Data.Array.Accelerate.Smart (
) where


import Data.Proxy
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
Expand All @@ -95,6 +100,7 @@ import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
Expand Down Expand Up @@ -1172,6 +1178,12 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil
where
x = SmartExp $ Prj PairIdxLeft a

-- Operators from Vec
mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a
mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down
14 changes: 14 additions & 0 deletions src/Data/Primitive/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Primitive.Vec
Expand All @@ -31,12 +32,16 @@ module Data.Primitive.Vec (
Vec8, pattern Vec8,
Vec16, pattern Vec16,

mkVec,

listOfVec,
liftVec,

) where

import Data.Proxy
import Control.Monad.ST
import Control.Monad.Reader
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Data.Text.Prettyprint.Doc
Expand Down Expand Up @@ -83,6 +88,14 @@ import GHC.Word
--
data Vec (n :: Nat) a = Vec ByteArray#

mkVec :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a
mkVec vs = runST $ do
let n :: Int = fromIntegral $ natVal $ Proxy @n
mba <- newByteArray (n * sizeOf (undefined :: a))
zipWithM_ (writeByteArray mba) [0..n] vs
ByteArray ba# <- unsafeFreezeByteArray mba
return $! Vec ba#

type role Vec nominal representational

instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where
Expand Down Expand Up @@ -259,6 +272,7 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do
ByteArray ba# <- unsafeFreezeByteArray mba
return $! Vec ba#


-- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able
-- to do this without copying, but I don't think the definition of ByteArray# is
-- exported (or it is deeply magical).
Expand Down