@@ -189,12 +189,14 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
189189}
190190
191191Value llLoad (RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
192- Value pred, Value falseVal, triton::CacheModifier cm) {
192+ Value pred, Value falseVal, int64_t alignmentBytes,
193+ triton::CacheModifier cm) {
193194
194195 // Try to emit llvm.intr.masked.load if we can. In theory the backend should
195196 // be happier because we emit less branchy code to optimize. The backend will
196197 // lower it down however it wants at some point.
197- if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
198+ if (alignmentBytes &&
199+ (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE)) {
198200 // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need
199201 // to bitcast to `vector<1xelemTy>` (and back)
200202 int64_t vecSize = getNumElements (elemTy);
@@ -203,7 +205,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
203205 Value maskVal = createVectorMaskFromPredicate (rewriter, loc, pred, vecSize);
204206 bool nt = (cm == triton::CacheModifier::CG);
205207 Value vecData = rewriter.create <LLVM::MaskedLoadOp>(
206- loc, vecType, ptr, maskVal, falseVal, vecSize , nt);
208+ loc, vecType, ptr, maskVal, falseVal, alignmentBytes , nt);
207209 // If it is not a vector, remember to bitcast back to a scalar
208210 vecData = bitcast (vecData, elemTy);
209211 return vecData;
@@ -237,20 +239,20 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
237239}
238240
239241void llStore (RewriterBase &rewriter, Location loc, Value ptr, Value val,
240- Value pred, triton::CacheModifier cm) {
242+ Value pred, int64_t alignmentBytes, triton::CacheModifier cm) {
241243 // Try to emit llvm.intr.masked.store if we can. In theory the backend should
242244 // be happier because we emit less branchy code to optimize. The backend will
243245 // lower it down however it wants at some point.
244- if (cm == triton::CacheModifier::NONE) {
246+ if (alignmentBytes && cm == triton::CacheModifier::NONE) {
245247 // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need
246248 // to bitcast to `vector<1xelemTy>`
247249 Type elemTy = val.getType ();
248250 int64_t vecSize = getNumElements (elemTy);
249251 Type vecType = castToVectorType (elemTy);
250252 val = bitcast (val, vecType);
251253 Value maskVal = createVectorMaskFromPredicate (rewriter, loc, pred, vecSize);
252- auto op =
253- rewriter. create <LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vecSize );
254+ auto op = rewriter. create <LLVM::MaskedStoreOp>(loc, val, ptr, maskVal,
255+ alignmentBytes );
254256 return ;
255257 }
256258
0 commit comments