Skip to content

Commit 1728824

Browse files
committed
Generalize the vectorizer to arbitrary user-defined index sets,
provided it can prove that the size is static and divisible by the vector width.
1 parent 2580bc9 commit 1728824

File tree

1 file changed

+52
-32
lines changed

1 file changed

+52
-32
lines changed

src/lib/Vectorize.hs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ import Util (allM, zipWithZ)
4949
-- TODO: Local vector values? We might want to pack short and pure for loops into vectors,
5050
-- to support things like float3 etc.
5151
data Stability
52-
= Uniform -- constant across vectorized dimension
53-
| Varying -- varying across vectorized dimension
54-
| Contiguous -- varying, but contiguous across vectorized dimension
52+
-- Constant across vectorized dimension, represented as a scalar
53+
= Uniform
54+
-- Varying across vectorized dimension, represented as a vector
55+
| Varying
56+
-- Varying, but contiguous across vectorized dimension; represented as a
57+
-- scalar carrying the first value
58+
| Contiguous
5559
| ProdStability [Stability]
5660
deriving (Eq, Show)
5761

@@ -168,25 +172,27 @@ vectorizeLoopsExpr expr = do
168172
narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
169173
let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
170174
case expr of
171-
PrimOp (DAMOp (Seq effs dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal n))) dest body))
172-
| n `mod` loopWidth == 0 -> (do
173-
safe <- vectorSafeEffect effs
174-
if safe
175-
then (do
176-
Distinct <- getDistinct
177-
let vn = n `div` loopWidth
178-
body' <- vectorizeSeq loopWidth body
179-
dest' <- renameM dest
180-
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
181-
return $ PrimOp $ DAMOp seqOp)
182-
else renameM expr)
183-
`catchErr` \errs -> do
184-
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
185-
ctx = mempty { messageCtx = [msg] }
186-
errs' = prependCtxToErrs ctx errs
187-
modify (<> LiftE errs')
188-
recurSeq expr
189-
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
175+
PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
176+
sz <- simplifyIxSize =<< renameM ixty
177+
case sz of
178+
Just n | n `mod` loopWidth == 0 -> (do
179+
safe <- vectorSafeEffect effs
180+
if safe
181+
then (do
182+
Distinct <- getDistinct
183+
let vn = n `div` loopWidth
184+
body' <- vectorizeSeq loopWidth ixty body
185+
dest' <- renameM dest
186+
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
187+
return $ PrimOp $ DAMOp seqOp)
188+
else renameM expr)
189+
`catchErr` \errs -> do
190+
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
191+
ctx = mempty { messageCtx = [msg] }
192+
errs' = prependCtxToErrs ctx errs
193+
modify (<> LiftE errs')
194+
recurSeq expr
195+
_ -> recurSeq expr
190196
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
191197
item' <- renameM item
192198
itemTy <- return $ getType item'
@@ -218,6 +224,15 @@ vectorizeLoopsExpr expr = do
218224
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
219225
recurSeq _ = error "Impossible"
220226

227+
-- | Try to determine the size of an index type as a compile-time literal.
--
-- Builds a block that applies the `Size` method of the index type's
-- dictionary (with no arguments) and runs `cheapReduce` on that block.
-- Returns `Just n` when the reduction yields a literal `IdxRepVal n`, and
-- `Nothing` in every other case. The loop vectorizer uses the result to
-- check that the trip count is a multiple of the vector loop width, and
-- falls back to the non-vectorized path on `Nothing`.
--
-- NOTE(review): presumably `cheapReduce` is a best-effort reduction, so
-- `Nothing` means "could not prove the size is static" rather than "the
-- size is dynamic" -- confirm against its definition.
simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m)
228+
=> IxType SimpIR n -> m n (Maybe Word32)
229+
simplifyIxSize ixty = do
230+
sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size []
231+
cheapReduce sizeMethod >>= \case
232+
Just (IdxRepVal n) -> return $ Just n
233+
_ -> return Nothing
234+
{-# INLINE simplifyIxSize #-}
235+
221236
-- Really we should check this by seeing whether there is an instance for a
222237
-- `Commutative` class, or something like that, but for now just pattern-match
223238
-- to detect scalar addition as the only monoid we recognize as commutative.
@@ -300,22 +315,27 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
300315
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
301316
safe _ = return False
302317

303-
vectorizeSeq :: Word32 -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
304-
vectorizeSeq loopWidth (UnaryLamExpr (b:>ty) body) = do
305-
(_, ty') <- case ty of
306-
ProdTy [ixTy, ref] -> do
307-
ixTy' <- renameM ixTy
318+
vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i
319+
-> TopVectorizeM i o (LamExpr SimpIR o)
320+
vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do
321+
newLoopTy <- case ty of
322+
ProdTy [_ixType, ref] -> do
308323
ref' <- renameM ref
309-
return (ixTy', ProdTy [IdxRepTy, ref'])
324+
return $ ProdTy [IdxRepTy, ref']
310325
_ -> error "Unexpected seq binder type"
326+
ixty' <- renameM ixty
311327
liftVectorizeM loopWidth $
312-
buildUnaryLamExpr (getNameHint b) ty' \ci -> do
313-
-- XXX: we're assuming `Fin n` here
328+
buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do
329+
-- The per-tile loop iterates on `Fin`
314330
(viOrd, dest) <- fromPair $ Var ci
315331
iOrd <- imul viOrd $ IdxRepVal loopWidth
316-
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal iOrd dest)) $
332+
-- TODO: It would be nice to cancel this UnsafeFromOrdinal with the
333+
-- Ordinal that will be taken later when indexing, but that should
334+
-- probably be a separate pass.
335+
i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd]
336+
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $
317337
vectorizeBlock body $> UnitVal
318-
vectorizeSeq _ _ = error "expected a unary lambda expression"
338+
vectorizeSeq _ _ _ = error "expected a unary lambda expression"
319339

320340
newtype VectorizeM i o a =
321341
VectorizeM { runVectorizeM ::

0 commit comments

Comments
 (0)