Skip to content

Commit e446273

Browse files
committed
Add vectorization under a loop, and another example even closer to the matmul kernel.
In the process, discovered and fixed two typing bugs that happened to cancel out on previous test cases: - Vectorizing a LamExpr may change the types of the arguments (if they are now vectors) - Vector-indexing returns an object of different type from the element type that was being indexed (namely, the vector of those), and vectorIndexRepVal in Imp needs to accommodate that.
1 parent 5fbe15a commit e446273

File tree

3 files changed

+75
-15
lines changed

3 files changed

+75
-15
lines changed

src/lib/Imp.hs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ toImpVectorOp = \case
427427
-- VectorIdx requires that tbl' have a scalar element type, which is
428428
-- ultimately enforced by `Lower.getVectorType` barfing on non-scalars.
429429
tbl <- atomToRepVal tbl'
430-
repValAtom =<< vectorIndexRepVal tbl i (toIVectorType vty)
430+
repValAtom =<< vectorIndexRepVal tbl i vty
431431
VectorSubref ref i vty -> do
432432
refDest <- atomToDest ref
433433
refi <- destToAtom <$> indexDest refDest i
@@ -1024,9 +1024,10 @@ naryIndexRepVal x (ix:ixs) = do
10241024

10251025
-- TODO: de-dup with indexDest?
10261026
indexRepValParam :: Emits n
1027-
=> RepVal SimpIR n -> SAtom n -> (IExpr n -> SubstImpM i n (IExpr n))
1028-
-> SubstImpM i n (RepVal SimpIR n)
1029-
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i func = do
1027+
=> SRepVal n -> SAtom n -> (SType n -> SType n)
1028+
-> (IExpr n -> SubstImpM i n (IExpr n))
1029+
-> SubstImpM i n (SRepVal n)
1030+
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do
10301031
eltTy' <- applySubst (b@>SubstVal i) eltTy
10311032
ord <- ordinalImp (IxType t d) i
10321033
leafTys <- typeToTree tabTy
@@ -1039,20 +1040,26 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i func =
10391040
case ixStruct of
10401041
EmptyAbs (Nest _ Empty) -> func ptr' >>= load
10411042
_ -> func ptr'
1042-
return $ RepVal eltTy' vals'
1043-
indexRepValParam _ _ _ = error "expected table type"
1043+
-- `func` may have changed the types of the `vals'`. The caller must also
1044+
-- supply `tyFunc` to reflect that change in the SType.
1045+
return $ RepVal (tyFunc eltTy') vals'
1046+
indexRepValParam _ _ _ _ = error "expected table type"
10441047
{-# INLINE indexRepValParam #-}
10451048

10461049
indexRepVal :: Emits n
10471050
=> RepVal SimpIR n -> SAtom n -> SubstImpM i n (RepVal SimpIR n)
1048-
indexRepVal rep i = indexRepValParam rep i return
1051+
indexRepVal rep i = indexRepValParam rep i id return
10491052
{-# INLINE indexRepVal #-}
10501053

10511054
vectorIndexRepVal :: Emits n
1052-
=> RepVal SimpIR n -> SAtom n -> IVectorType
1055+
=> RepVal SimpIR n -> SAtom n -> SType n
10531056
-> SubstImpM i n (RepVal SimpIR n)
1054-
vectorIndexRepVal rep i vty = indexRepValParam rep i action where
1055-
action ptr = castPtrToVectorType ptr vty
1057+
vectorIndexRepVal rep i vty =
1058+
-- Passing `const vty` here depends on knowing that `vectorIndexRepVal` is
1059+
-- only called on references of scalar base type, so that the give `vty` is,
1060+
-- actually, the type of the result of the indexing operation.
1061+
indexRepValParam rep i (const vty) action where
1062+
action ptr = castPtrToVectorType ptr (toIVectorType vty)
10561063
{-# INLINE vectorIndexRepVal #-}
10571064

10581065
projectDest :: Int -> Dest n -> Dest n

src/lib/Vectorize.hs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import QueryType
2828
import Types.Core
2929
import Types.OpNames qualified as P
3030
import Types.Primitives
31-
import Util (allM)
31+
import Util (allM, zipWithZ)
3232

3333
-- === Vectorization ===
3434

@@ -152,6 +152,16 @@ vectorizeLoopsDecls nest cont =
152152
extendSubst (b @> atomVarName v) $
153153
vectorizeLoopsDecls rest cont
154154

155+
vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
156+
vectorizeLoopsLamExpr (LamExpr bs body) = case bs of
157+
Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body)
158+
Nest (b:>ty) rest -> do
159+
ty' <- renameM ty
160+
withFreshBinder (getNameHint b) ty' \b' -> do
161+
extendRenamer (b @> binderName b') do
162+
LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body
163+
return $ LamExpr (Nest b' bs') body'
164+
155165
vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
156166
vectorizeLoopsExpr expr = do
157167
vectorByteWidth <- askVectorByteWidth
@@ -175,7 +185,8 @@ vectorizeLoopsExpr expr = do
175185
ctx = mempty { messageCtx = [msg] }
176186
errs' = prependCtxToErrs ctx errs
177187
modify (<> LiftE errs')
178-
renameM expr
188+
recurSeq expr
189+
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
179190
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
180191
item' <- renameM item
181192
itemTy <- return $ getType item'
@@ -197,6 +208,15 @@ vectorizeLoopsExpr expr = do
197208
vectorizeLoopsBlock body
198209
PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam)
199210
_ -> renameM expr
211+
where
212+
recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
213+
recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do
214+
effs' <- renameM effs
215+
ixty' <- renameM ixty
216+
dest' <- renameM dest
217+
body' <- vectorizeLoopsLamExpr body
218+
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
219+
recurSeq _ = error "Impossible"
200220

201221
-- Really we should check this by seeing whether there is an instance for a
202222
-- `Commutative` class, or something like that, but for now just pattern-match
@@ -331,7 +351,8 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of
331351
(VRename v) -> Var <$> toAtomVar v)
332352
(Nest (b:>ty) rest, (stab:stabs)) -> do
333353
ty' <- vectorizeType ty
334-
withFreshBinder (getNameHint b) ty' \b' -> do
354+
ty'' <- promoteTypeByStability ty' stab
355+
withFreshBinder (getNameHint b) ty'' \b' -> do
335356
var <- toAtomVar $ binderName b'
336357
extendSubst (b @> VVal stab (Var var)) do
337358
LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs
@@ -396,14 +417,16 @@ vectorizeRefOp ref' op =
396417
VVal xStab x <- vectorizeAtom x'
397418
basemonoid <- case refStab of
398419
Uniform -> case xStab of
399-
Uniform -> vectorizeBaseMonoid basemonoid' Uniform Uniform
420+
Uniform -> do
421+
vectorizeBaseMonoid basemonoid' Uniform Uniform
400422
-- This case represents accumulating something loop-varying into a
401423
-- loop-invariant accumulator, as e.g. sum. We can implement that for
402424
-- commutative monoids, but we would want to have started with private
403425
-- accumulators (one per lane), and then reduce them with an
404426
-- appropriate sequence of vector reduction intrinsics at the end.
405427
_ -> throwVectErr $ "Vectorizing non-sliced accumulation not implemented"
406-
Contiguous -> vectorizeBaseMonoid basemonoid' Varying xStab
428+
Contiguous -> do
429+
vectorizeBaseMonoid basemonoid' Varying xStab
407430
s -> throwVectErr $ "Cannot vectorize reference with loop-varying stability " ++ show s
408431
VVal Uniform <$> emitOp (RefOp ref $ MExtend basemonoid x)
409432
IndexRef _ i' -> do
@@ -543,6 +566,15 @@ ensureVarying (VRename v) = do
543566
x <- Var <$> toAtomVar v
544567
ensureVarying (VVal Uniform x)
545568

569+
promoteTypeByStability :: SType o -> Stability -> VectorizeM i o (SType o)
570+
promoteTypeByStability ty = \case
571+
Uniform -> return ty
572+
Contiguous -> return ty
573+
Varying -> getVectorType ty
574+
ProdStability stabs -> case ty of
575+
ProdTy elts -> ProdTy <$> zipWithZ promoteTypeByStability elts stabs
576+
_ -> throw ZipErr "Type and stability"
577+
546578
-- === computing byte widths ===
547579

548580
newtype CalcWidthM i o a = CalcWidthM {

tests/opt-tests.dx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,24 @@ _ = yield_accum (AddMonoid Int32) \ref.
190190
-- CHECK: [[xi:v#[0-9]+]]:<16xInt32> =
191191
-- CHECK-NEXT: vslice
192192
-- CHECK: extend [[refi]] [[xi]]
193+
194+
"vectorizing under an outer loop, like matmul"
195+
-- CHECK-LABEL: vectorizing under an outer loop, like matmul
196+
197+
mat1 = for i:(Fin 32). for j:(Fin 32).
198+
(n_to_i32 (ordinal i)) * (n_to_i32 (ordinal j)) + 1
199+
200+
mat2 = for i:(Fin 32). for j:(Fin 32).
201+
(n_to_i32 (ordinal i)) * (n_to_i32 (ordinal j)) + 7
202+
203+
%passes vect
204+
_ = yield_accum (AddMonoid Int32) \result.
205+
for k:(Fin 32).
206+
for j:(Fin 32).
207+
result!(3@(Fin 32))!j += mat1[3@_][k] * mat2[k][j]
208+
-- CHECK: seq (RawFin 0x2)
209+
-- CHECK: [[refj:v#[0-9]+]]:(Ref {{v#[0-9]+}} <16xInt32>) = vrefslice
210+
-- CHECK: [[mat2j:v#[0-9]+]]:<16xInt32> = vslice
211+
-- CHECK: [[mat1:v#[0-9]+]]:<16xInt32> = vbroadcast
212+
-- CHECK: [[prodj:v#[0-9]+]]:<16xInt32> = %imul [[mat1]] [[mat2j]]
213+
-- CHECK: extend [[refj]] [[prodj]]

0 commit comments

Comments
 (0)