Commit be93a5c

gflegar authored and Moerafaat committed
[BACKEND] Update LLVM version to llvm/llvm-project@7752e0a (triton-lang#6735)
Removed `addArgumentMaterialization` since that method was removed from MLIR in llvm/llvm-project@23e3cbb
1 parent bc33cfd commit be93a5c
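
Because `addArgumentMaterialization` no longer exists after llvm/llvm-project@23e3cbb, downstream `TypeConverter` setups simply drop that registration and let the source materialization cover block arguments as well. A minimal sketch of what that fallout looks like, with a hypothetical helper name and an identity conversion that are illustrative and not taken from this commit:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical helper, not part of this commit: after llvm/llvm-project@23e3cbb,
// TypeConverter has no addArgumentMaterialization, so that call is deleted and
// the source materialization is used for block arguments too.
static void populateExampleTypeConverter(mlir::TypeConverter &converter) {
  converter.addConversion([](mlir::Type type) { return type; });

  // Previously this would also have registered an argument materialization:
  //   converter.addArgumentMaterialization(buildCast);   // removed upstream
  converter.addSourceMaterialization(
      [](mlir::OpBuilder &builder, mlir::Type resultType,
         mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
        return builder
            .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
            .getResult(0);
      });
}
```

There is nothing to add in place of the deleted call; the source materialization MLIR already requires is reused for the argument case.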

File tree

2 files changed: +211 -0 lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 201 additions & 0 deletions
```diff
@@ -296,6 +296,207 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                           DecomposedWarpConversion decomposed,
                           OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const;
+
+  SmallVector<Value>
+  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
+                          const LinearLayout &srcLayout,
+                          const LinearLayout &dstLayout,
+                          ConversionPatternRewriter &rewriter) const {
+    MLIRContext *ctx = op.getContext();
+    auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    StringAttr kOffset = str_attr("offset");
+    StringAttr kIteration = str_attr("iteration");
+
+    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+
+    auto scratchConfig =
+        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
+    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
+        op.getSrc().getType().getEncoding(), op.getType().getShape()));
+    // Input dims: [offset, iteration, block]
+    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
+    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
+        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
+
+    // Layout for the store from registers to shared memory.
+    //
+    // Note: If two threads in the same warp write to the same shmem offset, the
+    // hardware resolves that without a stall or a bank conflict. Therefore we
+    // don't need to avoid duplicate writes.
+    // Input dims: [reg, lane, warp]
+    // Output dims: [offset, iteration]
+    bool isStMatrix = targetInfo.canUseStMatrix(
+        op.getSrc().getType(), scratchConfig.repShape,
+        scratchConfig.paddedRepShape, scratchConfig.order,
+        /*swizzleByteSize=*/0);
+    LinearLayout shmemStoreLayout =
+        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
+                                          /*swizzleByteSize=*/0)
+                   : srcLayout.invertAndCompose(sharedLayout);
+
+    const int shmemAllocatedNumElems =
+        getNumScratchElements(scratchConfig.paddedRepShape);
+    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
+
+    // Layout for the load from shmem to registers.
+    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
+
+    // Check that the `register` fully determines the `iteration`. That is,
+    // each thread does exactly the same reads and writes to shmem on each
+    // iteration, just with different input/output registers.
+    assert(
+        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
+    assert(
+        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
+
+    // iteration -> registers
+    SmallVector<SmallVector<int>> inRegsForIter =
+        collectRegsForIter(ctx, shmemStoreLayout);
+    SmallVector<SmallVector<int>> outRegsForIter =
+        collectRegsForIter(ctx, shmemLoadLayout);
+
+    Value smemBase =
+        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    auto sharedPtrTy = smemBase.getType();
+    Type elemTy = inVals[0].getType();
+    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
+    auto iterations = sharedLayout.getInDimSize(kIteration);
+    assert(scratchConfig.inVec * iterations <= inVals.size());
+    assert(scratchConfig.outVec * iterations <= outSize);
+
+    // Check only one dimension has been padded.
+    // This means the difference between the padded shape and the original shape
+    // should only be in one dimension, specifically in
+    // `scratchConfig.order[0]`.
+    auto rank = scratchConfig.repShape.size();
+    for (auto i = 0; i < rank; i++) {
+      if (i == scratchConfig.order[0]) {
+        continue;
+      }
+      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
+    }
+    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
+    auto paddedSize =
+        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
+
+    // Linear layout function is split in two parts below:
+    //
+    //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
+    //   offset = regBase xor regIdx
+    //
+    // It is the same hack as what we've done in the emitIndices function to get
+    // around performance issues on AMD GPUs
+    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
+                          int regSlice) -> Value {
+      auto regIdx = layout
+                        .apply({{kRegister, regSlice},
+                                {kLane, 0},
+                                {kWarp, 0},
+                                {kBlock, 0}})[0]
+                        .second;
+      Value offset = b.xor_(regBase, b.i32_val(regIdx));
+      if (paddedSize > 0) {
+        assert(llvm::isPowerOf2_32(paddedStride));
+        assert(llvm::isPowerOf2_32(paddedSize));
+        auto rshiftVal = llvm::Log2_32(paddedStride);
+        auto lshiftVal = llvm::Log2_32(paddedSize);
+        offset = b.add(
+            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
+            offset);
+      }
+      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset);
+      vecAddr.setNoWrapFlags(mlir::LLVM::GEPNoWrapFlags::inbounds);
+      return vecAddr;
+    };
+
+    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
+                                       {{kRegister, b.i32_val(0)},
+                                        {kLane, laneId},
+                                        {kWarp, warpId},
+                                        {kBlock, b.i32_val(0)}})[0]
+                         .second;
+    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
+                                      {{kRegister, b.i32_val(0)},
+                                       {kLane, laneId},
+                                       {kWarp, warpId},
+                                       {kBlock, b.i32_val(0)}})[0]
+                        .second;
+    // register idx -> Value
+    llvm::MapVector<int, Value> outVals;
+    for (int i = 0; i < iterations; i++) {
+      if (i != 0)
+        b.barrier();
+
+      auto &inRegs = inRegsForIter[i];
+      auto &outRegs = outRegsForIter[i];
+
+      // When using `stmatrix`, we can store `inVec` elements even if they are
+      // not contiguous
+      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
+                              : scratchConfig.inVec;
+      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
+        auto inRegSlice = inRegs[j];
+        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
+        SmallVector<Value> inValsVec;
+        for (int k = 0; k < inVec; k++)
+          inValsVec.push_back(inVals[inRegSlice + k]);
+        Value valsVec = packLLVector(loc, inValsVec, rewriter);
+        if (isStMatrix) {
+          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
+        } else {
+          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
+                                  /*pred=*/b.true_val());
+        }
+      }
+
+      b.barrier();
+
+      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
+        auto outRegSlice = outRegs[j];
+        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
+        Value valsVec =
+            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
+                                   vec_ty(elemTy, scratchConfig.outVec),
+                                   /*pred=*/b.true_val());
+        for (Value v : unpackLLVector(loc, valsVec, rewriter))
+          outVals[outRegSlice++] = v;
+      }
+    }
+
+    SmallVector<Value> outValsVec;
+    for (size_t i = 0; i < outVals.size(); i++)
+      outValsVec.push_back(outVals[i]);
+    return outValsVec;
+  }
+
+  // Determine which registers are read/written in which iteration of the shmem
+  // transfer specified by `layout`.
+  SmallVector<SmallVector<int> /*registers*/>
+  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    StringAttr kIteration = str_attr("iteration");
+
+    // The choice of iteration should be determined only by the register. That
+    // is, it should be correct to split the register dimension into iterations.
+    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
+
+    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
+    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
+    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
+      auto idx = sublayout.apply({{kRegister, reg}});
+      ret[idx.begin()->second].push_back(reg);
+    }
+    return ret;
+  }
 };
 
 } // namespace
```
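
The `getVecAddr` helper above leans on two facts: a layout that is linear over GF(2) distributes over xor (so the thread-dependent `regBase` is computed once and the per-register `regIdx` is xor-ed in), and the shift/add on the offset reproduces padding of `paddedSize` extra elements after every `paddedStride` elements. A standalone toy check of both, not Triton code, with made-up basis vectors and padding values:

```cpp
#include <cassert>
#include <cstdint>

// Toy GF(2)-linear layout: each set input bit contributes one basis vector,
// and contributions combine with xor (basis values are made up).
constexpr uint32_t kBasis[6] = {0b000001, 0b000100, 0b000010,
                                0b001000, 0b100000, 0b010000};

uint32_t applyToyLayout(uint32_t in) {
  uint32_t out = 0;
  for (int bit = 0; bit < 6; ++bit)
    if (in & (1u << bit))
      out ^= kBasis[bit];
  return out;
}

// Mirror of the padding arithmetic in getVecAddr: both arguments must be
// powers of two, exactly as the asserts in the real code require.
uint32_t padOffset(uint32_t offset, uint32_t paddedStride, uint32_t paddedSize) {
  auto log2u = [](uint32_t x) {
    uint32_t n = 0;
    while ((1u << n) < x)
      ++n;
    return n;
  };
  return ((offset >> log2u(paddedStride)) << log2u(paddedSize)) + offset;
}

int main() {
  // 1) Linearity: L(base ^ reg) == L(base) ^ L(reg), so the thread-dependent
  //    base can be computed once and the per-register index xor-ed in.
  for (uint32_t base = 0; base < 64; ++base)
    for (uint32_t reg = 0; reg < 64; ++reg)
      assert(applyToyLayout(base ^ reg) ==
             (applyToyLayout(base) ^ applyToyLayout(reg)));

  // 2) Padding: the shift/add form skips paddedSize elements after every
  //    paddedStride elements of the unpadded offset.
  for (uint32_t offset = 0; offset < 1024; ++offset)
    assert(padOffset(offset, /*paddedStride=*/32, /*paddedSize=*/4) ==
           offset + (offset / 32) * 4);
  return 0;
}
```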

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 10 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "llvm/TargetParser/TargetParser.h"
 
 #include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
 
@@ -88,6 +89,15 @@ struct ConvertTritonAMDGPUToLLVM
       mod.emitError("unsupported target: '") << this->arch.getValue() << "'";
       return signalPassFailure();
     }
+    llvm::StringRef chipset =
+        llvm::AMDGPU::getArchNameAMDGCN(targetInfo.getGPUKind());
+    llvm::FailureOr<mlir::amdgpu::Chipset> maybeChipset =
+        mlir::amdgpu::Chipset::parse(chipset);
+    if (failed(maybeChipset)) {
+      mlir::emitError(mlir::UnknownLoc::get(&getContext()),
+                      "Invalid chipset name: " + chipset);
+      return signalPassFailure();
+    }
 
     mlir::LowerToLLVMOptions option(context);
     option.overrideIndexBitwidth(32);
```
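
The new guard parses the AMDGCN arch name into an `mlir::amdgpu::Chipset` once and fails the pass if the dialect cannot understand it. A minimal sketch of the same check in isolation (the helper name is an assumption for illustration, and the header path is the one the AMDGPU dialect is expected to provide, neither taken from this commit):

```cpp
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "llvm/ADT/StringRef.h"

// Illustrative helper, not part of this commit: returns true only if the
// AMDGPU dialect can parse the arch string into a Chipset, which is the
// precondition the pass now enforces before any chipset-dependent lowering.
static bool isParsableChipset(llvm::StringRef arch) {
  llvm::FailureOr<mlir::amdgpu::Chipset> maybeChipset =
      mlir::amdgpu::Chipset::parse(arch);
  return succeeded(maybeChipset);
}
```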

0 commit comments