Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
/docs/_build
*.hi
*.o

hie.yaml
1 change: 1 addition & 0 deletions accelerate.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ library
Data.Array.Accelerate.Classes.RealFloat
Data.Array.Accelerate.Classes.RealFrac
Data.Array.Accelerate.Classes.ToFloating
Data.Array.Accelerate.Classes.Vector
Data.Array.Accelerate.Debug.Internal.Clock
Data.Array.Accelerate.Debug.Internal.Flags
Data.Array.Accelerate.Debug.Internal.Graph
Expand Down
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,17 @@ module Data.Array.Accelerate (

-- ** SIMD vectors
Vec, VecElt,
Vectoring(..),
vecOfList,
listOfVec,

-- ** Type classes
-- *** Basic type classes
Eq(..),
Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_,
Enum, succ, pred,
Bounded, minBound, maxBound,

-- Functor(..), (<$>), ($>), void,
-- Monad(..),

Expand Down Expand Up @@ -445,6 +449,7 @@ import Data.Array.Accelerate.Classes.Rational
import Data.Array.Accelerate.Classes.RealFloat
import Data.Array.Accelerate.Classes.RealFrac
import Data.Array.Accelerate.Classes.ToFloating
import Data.Array.Accelerate.Classes.Vector
import Data.Array.Accelerate.Data.Either
import Data.Array.Accelerate.Data.Maybe
import Data.Array.Accelerate.Language
Expand Down
47 changes: 34 additions & 13 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Primitive.Types
import Control.DeepSeq
import Data.Kind
import Data.Maybe
Expand Down Expand Up @@ -655,7 +656,6 @@ data PrimConst ty where
-- constant from Floating
PrimPi :: FloatingType a -> PrimConst a


-- |Primitive scalar operations
--
data PrimFun sig where
Expand Down Expand Up @@ -748,6 +748,10 @@ data PrimFun sig where
PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
PrimLNot :: PrimFun (PrimBool -> PrimBool)

-- local array operators
PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b)
Expand Down Expand Up @@ -825,7 +829,7 @@ expType = \case
While _ (Lam lhs _) _ -> lhsToTupR lhs
While{} -> error "What's the matter, you're running in the shadows"
Const tR _ -> TupRsingle tR
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index (Var repr _) _ -> arrayRtype repr
LinearIndex (Var repr _) _ -> arrayRtype repr
Expand All @@ -834,17 +838,20 @@ expType = \case
Undef tR -> TupRsingle tR
Coerce _ tR _ -> TupRsingle tR

primConstType :: PrimConst a -> SingleType a
primConstType :: PrimConst a -> ScalarType a
primConstType = \case
PrimMinBound t -> bounded t
PrimMaxBound t -> bounded t
PrimPi t -> floating t
where
bounded :: BoundedType a -> SingleType a
bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t
bounded :: BoundedType a -> ScalarType a
bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t

floating :: FloatingType t -> ScalarType t
floating = SingleScalarType . NumSingleType . FloatingNumType

floating :: FloatingType t -> SingleType t
floating = NumSingleType . FloatingNumType
vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a)
vector = VectorScalarType

primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType = \case
Expand Down Expand Up @@ -924,6 +931,17 @@ primFunType = \case
PrimLOr -> binary' tbool
PrimLNot -> unary' tbool

-- Local Vector operations
PrimVectorIndex v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` i, single a)

PrimVectorWrite v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` (i `TupRpair` single a), v)

-- general conversion between types
PrimFromIntegral a b -> unary (integral a) (num b)
PrimToFloating a b -> unary (num a) (floating b)
Expand All @@ -936,6 +954,7 @@ primFunType = \case
compare' a = binary (single a) tbool

single = TupRsingle . SingleScalarType
singleVector = TupRsingle . VectorScalarType
num = TupRsingle . SingleScalarType . NumSingleType
integral = num . IntegralNumType
floating = num . FloatingNumType
Expand Down Expand Up @@ -1100,9 +1119,9 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf =
rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b

rnfPrimConst :: PrimConst c -> ()
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t

rnfPrimFun :: PrimFun f -> ()
rnfPrimFun (PrimAdd t) = rnfNumType t
Expand Down Expand Up @@ -1165,6 +1184,7 @@ rnfPrimFun (PrimMin t) = rnfSingleType t
rnfPrimFun PrimLAnd = ()
rnfPrimFun PrimLOr = ()
rnfPrimFun PrimLNot = ()
rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i
rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n
rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f

Expand Down Expand Up @@ -1326,9 +1346,9 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||]
liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||]

liftPrimConst :: PrimConst c -> CodeQ (PrimConst c)
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]

liftPrimFun :: PrimFun f -> CodeQ (PrimFun f)
liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||]
Expand Down Expand Up @@ -1391,6 +1411,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||]
liftPrimFun PrimLAnd = [|| PrimLAnd ||]
liftPrimFun PrimLOr = [|| PrimLOr ||]
liftPrimFun PrimLNot = [|| PrimLNot ||]
liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||]
liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||]
liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||]

Expand Down
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq")
encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a
encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a
encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a
encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b
encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b
encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd")
Expand Down
3 changes: 1 addition & 2 deletions src/Data/Array/Accelerate/Classes/Enum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum"
preludeError :: String -> a
preludeError x
= error
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x
, ""
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , ""
, "These Prelude.Enum instances are present only to fulfil superclass"
, "constraints for subsequent classes in the standard Haskell numeric hierarchy."
]
Expand Down
33 changes: 33 additions & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module : Data.Array.Accelerate.Classes.Vector
-- Copyright : [2016..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Classes.Vector where

import Data.Kind
import GHC.TypeLits
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Smart
import Data.Primitive.Vec

instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
vecWrite = mkVectorWrite
vecEmpty = undef


12 changes: 12 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Elt as Sugar
import qualified Data.Array.Accelerate.Trafo.Delayed as AST

import GHC.TypeLits
import Control.DeepSeq
import Control.Exception
import Control.Monad
Expand Down Expand Up @@ -1144,6 +1145,8 @@ evalPrim (PrimMin ty) = evalMin ty
evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim (PrimVectorIndex v i) = evalVectorIndex v i
evalPrim (PrimVectorWrite v i) = evalVectorWrite v i
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb
evalPrim (PrimToFloating ta tb) = evalToFloating ta tb

Expand All @@ -1168,6 +1171,12 @@ evalLOr (x, y) = fromBool (toBool x || toBool y)
evalLNot :: PrimBool -> PrimBool
evalLNot = fromBool . not . toBool

evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a
evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i)

evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a
evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
| IntegralDict <- integralDict ta
Expand Down Expand Up @@ -1213,6 +1222,9 @@ evalMaxBound (IntegralBoundedType ty)
evalPi :: FloatingType a -> a
evalPi ty | FloatingDict <- floatingDict ty = pi

evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a
evalVectorCreate (VectorType n _) = vecEmpty

evalSin :: FloatingType a -> (a -> a)
evalSin ty | FloatingDict <- floatingDict ty = sin

Expand Down
25 changes: 23 additions & 2 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
Expand All @@ -12,6 +12,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Smart
Expand Down Expand Up @@ -71,6 +72,10 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for type coercion functions
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),

-- ** Smart constructors for vector operations
mkVectorIndex,
mkVectorWrite,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
Expand All @@ -83,6 +88,7 @@ module Data.Array.Accelerate.Smart (
) where


import Data.Proxy
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
Expand All @@ -95,6 +101,7 @@ import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
Expand Down Expand Up @@ -859,7 +866,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
Case{} -> internalError "encountered empty case"
Cond _ e _ -> typeR e
While t _ _ _ -> t
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index tp _ _ -> tp
LinearIndex tp _ _ -> tp
Expand Down Expand Up @@ -1172,6 +1179,17 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil
where
x = SmartExp $ Prj PairIdxLeft a

-- Operators from Vec
mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a
mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down Expand Up @@ -1259,6 +1277,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b)

mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d
mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c)))

mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary

Expand Down
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName )
import Data.Array.Accelerate.Trafo.Environment
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Vector

import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats

import Data.Bits
import Data.Monoid
import Data.Text ( Text )
import Data.Primitive.Vec
import Data.Text.Prettyprint.Doc
import Data.Text.Prettyprint.Doc.Render.Text
import GHC.Float ( float2Double, double2Float )
Expand Down Expand Up @@ -142,6 +144,8 @@ evalPrimApp env f x
PrimNEq ty -> evalNEq ty x env
PrimMax ty -> evalMax ty x env
PrimMin ty -> evalMin ty x env
PrimVectorIndex _ _ -> Nothing
PrimVectorWrite _ _ -> Nothing
PrimLAnd -> evalLAnd x env
PrimLOr -> evalLOr x env
PrimLNot -> evalLNot x env
Expand Down
Loading