
Commit 1df64d1

[AMD] Add alignment information to maskedLoad/maskedStore (#4816)
We should always set the correct alignment on the `maskedload`/`maskedstore` instructions. Previously the lowering passed the vector length `vecSize` in the position where the op builder takes the alignment in bytes, so the emitted intrinsics claimed an alignment unrelated to the actual pointer alignment.
1 parent 80947a2 commit 1df64d1
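
Why the alignment has to be right rather than merely plausible: overstating the alignment of a masked load/store is undefined behavior, while understating it only costs performance, so alignment analyses compute the greatest power of two guaranteed to divide the address. A standalone sketch of that conservative rule (illustrative code, not from the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Greatest power of two dividing `addr`; for addr == 0 every power of two
// divides, so any alignment claim is valid. This is the conservative notion
// of alignment that pointer-alignment analyses compute.
int64_t largestPow2Divisor(int64_t addr) {
  return addr == 0 ? int64_t(1) << 62 : (addr & -addr);
}

int main() {
  std::printf("%lld\n", (long long)largestPow2Divisor(0x1000)); // 4096
  std::printf("%lld\n", (long long)largestPow2Divisor(0x1004)); // 4
}
```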

File tree: 3 files changed, +24 −15 lines

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Lines changed: 13 additions & 6 deletions

@@ -119,6 +119,10 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getMaskAlignment(mask);
   }
 
+  unsigned getPtrAlignment(Value ptr) const {
+    return axisAnalysisPass.getPtrAlignment(ptr);
+  }
+
 protected:
   const AMD::TargetInfo &targetInfo;
   ModuleAxisInfoAnalysis &axisAnalysisPass;
@@ -193,7 +197,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     // vectorized iteration through all the pointer/mask/other elements
     const int valueElemNBits =
         std::max(8u, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
     const int numVecs = numElems / vec;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;
 
     auto cacheMod = op.getCache();
     SmallVector<Value> loadedVals;
@@ -230,8 +236,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
         falseVal = v;
       }
 
-      auto loadVal =
-          llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod);
+      Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal,
+                             ptrAlignmentBytes, cacheMod);
       for (size_t ii = 0; ii < vec; ++ii) {
         Value vecIdx = createIndexAttrConstant(
             rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec);
@@ -294,9 +300,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       vec = std::min(vec, maskAlign);
     }
 
-    const size_t dtsize =
-        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
-    const size_t valueElemNBits = dtsize * 8;
+    const size_t valueElemNBits =
+        std::max<int>(8, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;
 
     auto cacheMod = op.getCache();
     const int numVecs = elemsPerThread / vec;
@@ -328,7 +335,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
            rewriter, loc, this->getTypeConverter()->getIndexType(), s);
        storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
      }
-     llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
+     llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod);
    } // end vec
    rewriter.eraseOp(op);
    return success();
third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Lines changed: 9 additions & 7 deletions

@@ -189,12 +189,14 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
 }
 
 Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
-             Value pred, Value falseVal, triton::CacheModifier cm) {
+             Value pred, Value falseVal, int64_t alignmentBytes,
+             triton::CacheModifier cm) {
 
   // Try to emit llvm.intr.masked.load if we can. In theory the backend should
   // be happier because we emit less branchy code to optimize. The backend will
   // lower it down however it wants at some point.
-  if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
+  if (alignmentBytes &&
+      (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE)) {
     // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need
     // to bitcast to `vector<1xelemTy>` (and back)
     int64_t vecSize = getNumElements(elemTy);
@@ -203,7 +205,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
     Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
     bool nt = (cm == triton::CacheModifier::CG);
     Value vecData = rewriter.create<LLVM::MaskedLoadOp>(
-        loc, vecType, ptr, maskVal, falseVal, vecSize, nt);
+        loc, vecType, ptr, maskVal, falseVal, alignmentBytes, nt);
     // If it is not a vector, remember to bitcast back to a scalar
     vecData = bitcast(vecData, elemTy);
     return vecData;
@@ -237,20 +239,20 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
 }
 
 void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
-             Value pred, triton::CacheModifier cm) {
+             Value pred, int64_t alignmentBytes, triton::CacheModifier cm) {
   // Try to emit llvm.intr.masked.store if we can. In theory the backend should
   // be happier because we emit less branchy code to optimize. The backend will
   // lower it down however it wants at some point.
-  if (cm == triton::CacheModifier::NONE) {
+  if (alignmentBytes && cm == triton::CacheModifier::NONE) {
     // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need
     // to bitcast to `vector<1xelemTy>`
     Type elemTy = val.getType();
     int64_t vecSize = getNumElements(elemTy);
     Type vecType = castToVectorType(elemTy);
     val = bitcast(val, vecType);
     Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
-    auto op =
-        rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vecSize);
+    auto op = rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal,
+                                                   alignmentBytes);
     return;
   }
 
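Note the new guard in both helpers: a zero `alignmentBytes` now disables the masked-intrinsic fast path entirely, falling back to the branchy per-element lowering rather than emitting an intrinsic with an unknown alignment. A standalone restatement of the two predicates (enum and function names are illustrative, not from the patch):

```cpp
#include <cstdint>
#include <cstdio>

enum class CacheModifier { NONE, CA, CG };

// The masked-load fast path is taken only when a nonzero alignment is known
// and the cache modifier is NONE or CG (CG maps to a nontemporal hint above).
bool useMaskedLoadIntrinsic(int64_t alignmentBytes, CacheModifier cm) {
  return alignmentBytes != 0 &&
         (cm == CacheModifier::CG || cm == CacheModifier::NONE);
}

// The masked-store fast path additionally requires CacheModifier::NONE.
bool useMaskedStoreIntrinsic(int64_t alignmentBytes, CacheModifier cm) {
  return alignmentBytes != 0 && cm == CacheModifier::NONE;
}

int main() {
  std::printf("%d\n", useMaskedLoadIntrinsic(16, CacheModifier::CG));  // 1
  std::printf("%d\n", useMaskedLoadIntrinsic(0, CacheModifier::NONE)); // 0
  std::printf("%d\n", useMaskedStoreIntrinsic(16, CacheModifier::CG)); // 0
}
```
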
third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
Lines changed: 2 additions & 2 deletions

@@ -30,12 +30,12 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
 // Loads from shared or global memory with predication.
 // `otherElems` is used to mask out the elements that are not loaded
 Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
-             Value pred, Value falseVal,
+             Value pred, Value falseVal, int64_t alignmentBytes = 0,
              triton::CacheModifier cm = triton::CacheModifier::NONE);
 
 // Stores to shared or global memory with predication.
 void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
-             Value pred,
+             Value pred, int64_t alignmentBytes = 0,
              triton::CacheModifier cm = triton::CacheModifier::NONE);
 } // namespace mlir::LLVM::AMD
 
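
Since `alignmentBytes` defaults to 0, callers that omitted the trailing arguments keep compiling and simply take the fallback path; callers that passed the cache modifier positionally, as in LoadStoreOpToLLVM.cpp above, had to be updated because the new parameter sits in front of it. A self-contained illustration of this API-evolution pattern (all names hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

enum class CacheModifier { NONE, CG };

// A new defaulted parameter inserted before an existing defaulted one: calls
// that omitted both trailing arguments are unaffected, while calls that
// supplied the cache modifier positionally must now thread the alignment
// through as well.
void llStoreLike(int payload, int64_t alignmentBytes = 0,
                 CacheModifier cm = CacheModifier::NONE) {
  std::printf("payload=%d align=%lld cm=%d\n", payload,
              (long long)alignmentBytes, (int)cm);
}

int main() {
  llStoreLike(1);                        // pre-patch style call still compiles
  llStoreLike(2, 16);                    // known 16-byte alignment
  llStoreLike(3, 16, CacheModifier::CG); // alignment plus cache modifier
}
```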
