@@ -49,9 +49,13 @@ import Util (allM, zipWithZ)
4949-- TODO: Local vector values? We might want to pack short and pure for loops into vectors,
5050-- to support things like float3 etc.
5151data Stability
52- = Uniform -- constant across vectorized dimension
53- | Varying -- varying across vectorized dimension
54- | Contiguous -- varying, but contiguous across vectorized dimension
52+ -- Constant across vectorized dimension, represented as a scalar
53+ = Uniform
54+ -- Varying across vectorized dimension, represented as a vector
55+ | Varying
56+ -- Varying, but contiguous across vectorized dimension; represented as a
57+ -- scalar carrying the first value
58+ | Contiguous
5559 | ProdStability [Stability ] -- presumably one stability per component of a product value -- TODO confirm
5660 deriving (Eq , Show )
5761
@@ -168,25 +172,27 @@ vectorizeLoopsExpr expr = do
168172 narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
169173 let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
170174 case expr of
171- PrimOp (DAMOp (Seq effs dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal n))) dest body))
172- | n `mod` loopWidth == 0 -> (do
173- safe <- vectorSafeEffect effs
174- if safe
175- then (do
176- Distinct <- getDistinct
177- let vn = n `div` loopWidth
178- body' <- vectorizeSeq loopWidth body
179- dest' <- renameM dest
180- seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
181- return $ PrimOp $ DAMOp seqOp)
182- else renameM expr)
183- `catchErr` \ errs -> do
184- let msg = " In `vectorizeLoopsDecls`:\n Expr:\n " ++ pprint expr
185- ctx = mempty { messageCtx = [msg] }
186- errs' = prependCtxToErrs ctx errs
187- modify (<> LiftE errs')
188- recurSeq expr
189- PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
175+ PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
176+ sz <- simplifyIxSize =<< renameM ixty
177+ case sz of
178+ Just n | n `mod` loopWidth == 0 -> (do
179+ safe <- vectorSafeEffect effs
180+ if safe
181+ then (do
182+ Distinct <- getDistinct
183+ let vn = n `div` loopWidth
184+ body' <- vectorizeSeq loopWidth ixty body
185+ dest' <- renameM dest
186+ seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
187+ return $ PrimOp $ DAMOp seqOp)
188+ else renameM expr)
189+ `catchErr` \ errs -> do
190+ let msg = " In `vectorizeLoopsDecls`:\n Expr:\n " ++ pprint expr
191+ ctx = mempty { messageCtx = [msg] }
192+ errs' = prependCtxToErrs ctx errs
193+ modify (<> LiftE errs')
194+ recurSeq expr
195+ _ -> recurSeq expr
190196 PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
191197 item' <- renameM item
192198 itemTy <- return $ getType item'
@@ -218,6 +224,15 @@ vectorizeLoopsExpr expr = do
218224 return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
219225 recurSeq _ = error " Impossible"
220226
227+ simplifyIxSize :: (EnvReader m , ScopableBuilder SimpIR m ) -- Try to reduce the size of an index type to a literal.
228+ => IxType SimpIR n -> m n (Maybe Word32 ) -- `Just n` only when the size reduces to an `IdxRepVal` literal.
229+ simplifyIxSize ixty = do
230+ sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] -- block that calls the dictionary's `Size` method with no arguments
231+ cheapReduce sizeMethod >>= \ case
232+ Just (IdxRepVal n) -> return $ Just n -- the block reduced to a literal size
233+ _ -> return Nothing -- reduction failed or did not yield an `IdxRepVal`; the caller treats this as "not vectorizable"
234+ {-# INLINE simplifyIxSize #-}
235+
221236-- Really we should check this by seeing whether there is an instance for a
222237-- `Commutative` class, or something like that, but for now just pattern-match
223238-- to detect scalar addition as the only monoid we recognize as commutative.
@@ -300,22 +315,27 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
300315 Nothing -> error $ " Handle " ++ pprint h ++ " not present in commute map?"
301316 safe _ = return False
302317
303- vectorizeSeq :: Word32 -> LamExpr SimpIR i -> TopVectorizeM i o ( LamExpr SimpIR o )
304- vectorizeSeq loopWidth ( UnaryLamExpr (b :> ty) body) = do
305- (_, ty') <- case ty of
306- ProdTy [ixTy, ref] -> do
307- ixTy' <- renameM ixTy
318+ vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i -- Arguments: loop width, the original loop's index type, the loop body lambda.
319+ -> TopVectorizeM i o ( LamExpr SimpIR o ) -- Result: a loop body lambda whose binder is `(IdxRepTy, ref)`.
320+ vectorizeSeq loopWidth ixty ( UnaryLamExpr (b :> ty) body) = do
321+ newLoopTy <- case ty of -- binder type of the rewritten loop: the index slot becomes a raw `IdxRepTy`
322+ ProdTy [_ixType, ref] -> do
308323 ref' <- renameM ref
309- return (ixTy', ProdTy [IdxRepTy , ref'])
324+ return $ ProdTy [IdxRepTy , ref']
310325 _ -> error " Unexpected seq binder type"
326+ ixty' <- renameM ixty -- the original index type, renamed into the output scope
311327 liftVectorizeM loopWidth $
312- buildUnaryLamExpr (getNameHint b) ty' \ ci -> do
313- -- XXX: we're assuming `Fin n` here
328+ buildUnaryLamExpr (getNameHint b) newLoopTy \ ci -> do
329+ -- The per-tile loop iterates on `Fin`; `ci` is a pair of the tile ordinal and the dest
314330 (viOrd, dest) <- fromPair $ Var ci
315331 iOrd <- imul viOrd $ IdxRepVal loopWidth -- tile ordinal scaled by the loop width
316- extendSubst (b @> VVal (ProdStability [Contiguous , ProdStability [Uniform ]]) (PairVal iOrd dest)) $
332+ -- TODO: It would be nice to cancel this UnsafeFromOrdinal with the
333+ -- Ordinal that will be taken later when indexing, but that should
334+ -- probably be a separate pass.
335+ i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd] -- turn the ordinal back into a value of the original index type
336+ extendSubst (b @> VVal (ProdStability [Contiguous , ProdStability [Uniform ]]) (PairVal i dest)) $ -- the index is marked Contiguous, the dest Uniform
317337 vectorizeBlock body $> UnitVal
318- vectorizeSeq _ _ = error " expected a unary lambda expression"
338+ vectorizeSeq _ _ _ = error " expected a unary lambda expression"
319339
320340newtype VectorizeM i o a =
321341 VectorizeM { runVectorizeM ::
@@ -467,9 +487,13 @@ vectorizePrimOp op = case op of
467487 BinOp opk arg1 arg2 -> do
468488 sx@ (VVal vx x) <- vectorizeAtom arg1
469489 sy@ (VVal vy y) <- vectorizeAtom arg2
470- let v = case (vx, vy) of (Uniform , Uniform ) -> Uniform ; _ -> Varying
471- x' <- if vx /= v then ensureVarying sx else return x
472- y' <- if vy /= v then ensureVarying sy else return y
490+ let v = case (opk, vx, vy) of
491+ (_, Uniform , Uniform ) -> Uniform
492+ (IAdd , Uniform , Contiguous ) -> Contiguous
493+ (IAdd , Contiguous , Uniform ) -> Contiguous
494+ _ -> Varying
495+ x' <- if v == Varying then ensureVarying sx else return x
496+ y' <- if v == Varying then ensureVarying sy else return y
473497 VVal v <$> emitOp (BinOp opk x' y')
474498 MiscOp (CastOp tyArg arg) -> do
475499 ty <- vectorizeType tyArg
0 commit comments