Skip to content

Commit 4c47ab0

Browse files
authored
Merge pull request #1316 from axch/vectorize-user-index-sets
Vectorize through user-defined index sets
2 parents 3fbcc02 + c8c0ae3 commit 4c47ab0

File tree

3 files changed

+117
-41
lines changed

3 files changed

+117
-41
lines changed

src/lib/ImpToLLVM.hs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -902,16 +902,33 @@ withWidthOfFP x template = case typeOf template of
902902
L.FloatingPointType L.FloatFP -> litVal $ Float32Lit $ realToFrac x
903903
_ -> error $ "Unsupported floating point type: " ++ show (typeOf template)
904904

905+
-- If we are accessing a `L.Type` from a Dex array, what memory alignment (in
906+
-- bytes) can we guarantee? This is probably better expressed in Dex types, but
907+
-- we would need to plumb them to do it that way. 1-byte alignment should
908+
-- always be safe, but we can promise higher-performance alignments for some
909+
-- types.
910+
dexAlignment :: L.Type -> Word32
911+
dexAlignment = \case
912+
L.IntegerType bits | bits `mod` 8 == 0 -> bits `div` 8
913+
L.IntegerType _ -> 1
914+
L.PointerType _ _ -> 4
915+
L.FloatingPointType L.FloatFP -> 4
916+
L.FloatingPointType L.DoubleFP -> 8
917+
L.VectorType _ eltTy -> dexAlignment eltTy
918+
_ -> 1
919+
905920
store :: LLVMBuilder m => Operand -> Operand -> m ()
906-
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing 0 []
921+
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing alignment [] where
922+
alignment = dexAlignment $ typeOf x
907923

908924
load :: LLVMBuilder m => L.Type -> Operand -> m Operand
909925
load pointeeTy ptr =
910926
#if MIN_VERSION_llvm_hs(15,0,0)
911-
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing 0 []
927+
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing alignment []
912928
#else
913-
emitInstr pointeeTy $ L.Load False ptr Nothing 0 []
929+
emitInstr pointeeTy $ L.Load False ptr Nothing alignment []
914930
#endif
931+
where alignment = dexAlignment pointeeTy
915932

916933
ilt :: LLVMBuilder m => Operand -> Operand -> m Operand
917934
ilt x y = emitInstr i1 $ L.ICmp IP.SLT x y []

src/lib/Vectorize.hs

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ import Util (allM, zipWithZ)
4949
-- TODO: Local vector values? We might want to pack short and pure for loops into vectors,
5050
-- to support things like float3 etc.
5151
data Stability
52-
= Uniform -- constant across vectorized dimension
53-
| Varying -- varying across vectorized dimension
54-
| Contiguous -- varying, but contiguous across vectorized dimension
52+
-- Constant across vectorized dimension, represented as a scalar
53+
= Uniform
54+
-- Varying across vectorized dimension, represented as a vector
55+
| Varying
56+
-- Varying, but contiguous across vectorized dimension; represented as a
57+
-- scalar carrying the first value
58+
| Contiguous
5559
| ProdStability [Stability]
5660
deriving (Eq, Show)
5761

@@ -168,25 +172,27 @@ vectorizeLoopsExpr expr = do
168172
narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
169173
let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
170174
case expr of
171-
PrimOp (DAMOp (Seq effs dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal n))) dest body))
172-
| n `mod` loopWidth == 0 -> (do
173-
safe <- vectorSafeEffect effs
174-
if safe
175-
then (do
176-
Distinct <- getDistinct
177-
let vn = n `div` loopWidth
178-
body' <- vectorizeSeq loopWidth body
179-
dest' <- renameM dest
180-
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
181-
return $ PrimOp $ DAMOp seqOp)
182-
else renameM expr)
183-
`catchErr` \errs -> do
184-
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
185-
ctx = mempty { messageCtx = [msg] }
186-
errs' = prependCtxToErrs ctx errs
187-
modify (<> LiftE errs')
188-
recurSeq expr
189-
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
175+
PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
176+
sz <- simplifyIxSize =<< renameM ixty
177+
case sz of
178+
Just n | n `mod` loopWidth == 0 -> (do
179+
safe <- vectorSafeEffect effs
180+
if safe
181+
then (do
182+
Distinct <- getDistinct
183+
let vn = n `div` loopWidth
184+
body' <- vectorizeSeq loopWidth ixty body
185+
dest' <- renameM dest
186+
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
187+
return $ PrimOp $ DAMOp seqOp)
188+
else renameM expr)
189+
`catchErr` \errs -> do
190+
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
191+
ctx = mempty { messageCtx = [msg] }
192+
errs' = prependCtxToErrs ctx errs
193+
modify (<> LiftE errs')
194+
recurSeq expr
195+
_ -> recurSeq expr
190196
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
191197
item' <- renameM item
192198
itemTy <- return $ getType item'
@@ -218,6 +224,15 @@ vectorizeLoopsExpr expr = do
218224
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
219225
recurSeq _ = error "Impossible"
220226

227+
simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m)
228+
=> IxType SimpIR n -> m n (Maybe Word32)
229+
simplifyIxSize ixty = do
230+
sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size []
231+
cheapReduce sizeMethod >>= \case
232+
Just (IdxRepVal n) -> return $ Just n
233+
_ -> return Nothing
234+
{-# INLINE simplifyIxSize #-}
235+
221236
-- Really we should check this by seeing whether there is an instance for a
222237
-- `Commutative` class, or something like that, but for now just pattern-match
223238
-- to detect scalar addition as the only monoid we recognize as commutative.
@@ -300,22 +315,27 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
300315
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
301316
safe _ = return False
302317

303-
vectorizeSeq :: Word32 -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
304-
vectorizeSeq loopWidth (UnaryLamExpr (b:>ty) body) = do
305-
(_, ty') <- case ty of
306-
ProdTy [ixTy, ref] -> do
307-
ixTy' <- renameM ixTy
318+
vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i
319+
-> TopVectorizeM i o (LamExpr SimpIR o)
320+
vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do
321+
newLoopTy <- case ty of
322+
ProdTy [_ixType, ref] -> do
308323
ref' <- renameM ref
309-
return (ixTy', ProdTy [IdxRepTy, ref'])
324+
return $ ProdTy [IdxRepTy, ref']
310325
_ -> error "Unexpected seq binder type"
326+
ixty' <- renameM ixty
311327
liftVectorizeM loopWidth $
312-
buildUnaryLamExpr (getNameHint b) ty' \ci -> do
313-
-- XXX: we're assuming `Fin n` here
328+
buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do
329+
-- The per-tile loop iterates on `Fin`
314330
(viOrd, dest) <- fromPair $ Var ci
315331
iOrd <- imul viOrd $ IdxRepVal loopWidth
316-
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal iOrd dest)) $
332+
-- TODO: It would be nice to cancel this UnsafeFromOrdinal with the
333+
-- Ordinal that will be taken later when indexing, but that should
334+
-- probably be a separate pass.
335+
i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd]
336+
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $
317337
vectorizeBlock body $> UnitVal
318-
vectorizeSeq _ _ = error "expected a unary lambda expression"
338+
vectorizeSeq _ _ _ = error "expected a unary lambda expression"
319339

320340
newtype VectorizeM i o a =
321341
VectorizeM { runVectorizeM ::
@@ -467,9 +487,13 @@ vectorizePrimOp op = case op of
467487
BinOp opk arg1 arg2 -> do
468488
sx@(VVal vx x) <- vectorizeAtom arg1
469489
sy@(VVal vy y) <- vectorizeAtom arg2
470-
let v = case (vx, vy) of (Uniform, Uniform) -> Uniform; _ -> Varying
471-
x' <- if vx /= v then ensureVarying sx else return x
472-
y' <- if vy /= v then ensureVarying sy else return y
490+
let v = case (opk, vx, vy) of
491+
(_, Uniform, Uniform) -> Uniform
492+
(IAdd, Uniform, Contiguous) -> Contiguous
493+
(IAdd, Contiguous, Uniform) -> Contiguous
494+
_ -> Varying
495+
x' <- if v == Varying then ensureVarying sx else return x
496+
y' <- if v == Varying then ensureVarying sy else return y
473497
VVal v <$> emitOp (BinOp opk x' y')
474498
MiscOp (CastOp tyArg arg) -> do
475499
ty <- vectorizeType tyArg

tests/opt-tests.dx

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ _ = for i:(Fin 20) j:(Fin 4). ordinal j
126126
"vectorizing int binary op"
127127
-- CHECK-LABEL: vectorizing int binary op
128128
%passes vect
129-
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1
129+
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) * 2
130130
-- CHECK: seq (RawFin 0x10)
131131
-- CHECK: [[i0:v#[0-9]+]]:<16xInt32> = vbroadcast
132132
-- CHECK: [[i1:v#[0-9]+]]:<16xInt32> = viota
133133
-- CHECK: [[i2:v#[0-9]+]]:<16xInt32> = %iadd [[i0]] [[i1]]
134-
-- CHECK: [[ones:v#[0-9]+]]:<16xInt32> = vbroadcast 1
135-
-- CHECK: %iadd [[i2]] [[ones]]
134+
-- CHECK: [[twos:v#[0-9]+]]:<16xInt32> = vbroadcast 2
135+
-- CHECK: %imul [[i2]] [[twos]]
136136

137137
"vectorizing float binary op"
138138
-- CHECK-LABEL: vectorizing float binary op
@@ -211,3 +211,38 @@ _ = yield_accum (AddMonoid Int32) \result.
211211
-- CHECK: [[mat1:v#[0-9]+]]:<16xInt32> = vbroadcast
212212
-- CHECK: [[prodj:v#[0-9]+]]:<16xInt32> = %imul [[mat1]] [[mat2j]]
213213
-- CHECK: extend [[refj]] [[prodj]]
214+
215+
"vectorizing through the `tile` combinator and its funny index set"
216+
-- CHECK-LABEL: vectorizing through the `tile` combinator and its funny index set
217+
218+
%passes vect
219+
_ = yield_accum (AddMonoid Int32) \result.
220+
tile((Fin 256), 32) \set.
221+
for_ i:set.
222+
ix = inject(i, to=(Fin 256))
223+
result!ix += xs[ix]
224+
-- CHECK: seq (RawFin 0x8)
225+
-- CHECK: seq (RawFin 0x2)
226+
-- CHECK: [[refix:v#[0-9]+]]:(Ref {{v#[0-9]+}} <16xInt32>) = vrefslice
227+
-- CHECK: [[xsix:v#[0-9]+]]:<16xInt32> =
228+
-- CHECK-NEXT: vslice
229+
-- CHECK: extend [[refix]] [[xsix]]
230+
231+
"Non-aligned"
232+
-- CHECK-LABEL: Non-aligned
233+
234+
-- This is a regression test. We are checking that Dex-side
235+
-- vectorization does not end up assuming that arrays are aligned on
236+
-- the size of the vectors, only on the size of the underlying
237+
-- scalars.
238+
239+
non_aligned = for i:(Fin 7). for j:(Fin 257). +0
240+
241+
%passes llvm
242+
_ = yield_accum (AddMonoid Int32) \result.
243+
tile((Fin 257), 32) \set.
244+
for_ i:set.
245+
ix = inject(i, to=(Fin 257))
246+
result!(6@(Fin 7))!ix += non_aligned[6@_][ix]
247+
-- CHECK: load <16 x i32>, <16 x i32>* %"v#{{[0-9]+}}", align 4
248+
-- CHECK: store <16 x i32> %"v#{{[0-9]+}}", <16 x i32>* %"v#{{[0-9]+}}", align 4

0 commit comments

Comments
 (0)