@@ -28,7 +28,7 @@ import QueryType
2828import Types.Core
2929import Types.OpNames qualified as P
3030import Types.Primitives
31- import Util (allM )
31+ import Util (allM , zipWithZ )
3232
3333-- === Vectorization ===
3434
@@ -152,6 +152,16 @@ vectorizeLoopsDecls nest cont =
152152 extendSubst (b @> atomVarName v) $
153153 vectorizeLoopsDecls rest cont
154154
155+ vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o )
156+ vectorizeLoopsLamExpr (LamExpr bs body) = case bs of
157+ Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body)
158+ Nest (b:> ty) rest -> do
159+ ty' <- renameM ty
160+ withFreshBinder (getNameHint b) ty' \ b' -> do
161+ extendRenamer (b @> binderName b') do
162+ LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body
163+ return $ LamExpr (Nest b' bs') body'
164+
155165vectorizeLoopsExpr :: (Emits o ) => SExpr i -> TopVectorizeM i o (SExpr o )
156166vectorizeLoopsExpr expr = do
157167 vectorByteWidth <- askVectorByteWidth
@@ -175,7 +185,8 @@ vectorizeLoopsExpr expr = do
175185 ctx = mempty { messageCtx = [msg] }
176186 errs' = prependCtxToErrs ctx errs
177187 modify (<> LiftE errs')
178- renameM expr
188+ recurSeq expr
189+ PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
179190 PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
180191 item' <- renameM item
181192 itemTy <- return $ getType item'
@@ -197,6 +208,15 @@ vectorizeLoopsExpr expr = do
197208 vectorizeLoopsBlock body
198209 PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam)
199210 _ -> renameM expr
211+ where
212+ recurSeq :: (Emits o ) => SExpr i -> TopVectorizeM i o (SExpr o )
213+ recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do
214+ effs' <- renameM effs
215+ ixty' <- renameM ixty
216+ dest' <- renameM dest
217+ body' <- vectorizeLoopsLamExpr body
218+ return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
219+ recurSeq _ = error " Impossible"
200220
201221-- Really we should check this by seeing whether there is an instance for a
202222-- `Commutative` class, or something like that, but for now just pattern-match
@@ -331,7 +351,8 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of
331351 (VRename v) -> Var <$> toAtomVar v)
332352 (Nest (b:> ty) rest, (stab: stabs)) -> do
333353 ty' <- vectorizeType ty
334- withFreshBinder (getNameHint b) ty' \ b' -> do
354+ ty'' <- promoteTypeByStability ty' stab
355+ withFreshBinder (getNameHint b) ty'' \ b' -> do
335356 var <- toAtomVar $ binderName b'
336357 extendSubst (b @> VVal stab (Var var)) do
337358 LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs
@@ -396,14 +417,16 @@ vectorizeRefOp ref' op =
396417 VVal xStab x <- vectorizeAtom x'
397418 basemonoid <- case refStab of
398419 Uniform -> case xStab of
399- Uniform -> vectorizeBaseMonoid basemonoid' Uniform Uniform
420+ Uniform -> do
421+ vectorizeBaseMonoid basemonoid' Uniform Uniform
400422 -- This case represents accumulating something loop-varying into a
401423 -- loop-invariant accumulator, as e.g. sum. We can implement that for
402424 -- commutative monoids, but we would want to have started with private
403425 -- accumulators (one per lane), and then reduce them with an
404426 -- appropriate sequence of vector reduction intrinsics at the end.
405427 _ -> throwVectErr $ " Vectorizing non-sliced accumulation not implemented"
406- Contiguous -> vectorizeBaseMonoid basemonoid' Varying xStab
428+ Contiguous -> do
429+ vectorizeBaseMonoid basemonoid' Varying xStab
407430 s -> throwVectErr $ " Cannot vectorize reference with loop-varying stability " ++ show s
408431 VVal Uniform <$> emitOp (RefOp ref $ MExtend basemonoid x)
409432 IndexRef _ i' -> do
@@ -543,6 +566,15 @@ ensureVarying (VRename v) = do
543566 x <- Var <$> toAtomVar v
544567 ensureVarying (VVal Uniform x)
545568
569+ promoteTypeByStability :: SType o -> Stability -> VectorizeM i o (SType o )
570+ promoteTypeByStability ty = \ case
571+ Uniform -> return ty
572+ Contiguous -> return ty
573+ Varying -> getVectorType ty
574+ ProdStability stabs -> case ty of
575+ ProdTy elts -> ProdTy <$> zipWithZ promoteTypeByStability elts stabs
576+ _ -> throw ZipErr " Type and stability"
577+
546578-- === computing byte widths ===
547579
548580newtype CalcWidthM i o a = CalcWidthM {
0 commit comments