Commit bae3b79
[BACKEND] Take indices rather than reps as inputs to {smem,tmem}Load (#8623)
The previous code did a poor job of guessing the CTA-level size of the tile being lowered. Here, we give up on guessing entirely and instead ask the caller to provide the starting coordinates of the subtensor they want to lower, rather than the reps. In passing, we also switch the `tmemLoad` logic to use LinearLayouts.

Fixes triton-lang/triton#8606
1 parent 781273a commit bae3b79
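The gist of the interface change, as a standalone sketch (the loader type and tile sizes below are stand-ins, not Triton code): callers used to pass rep indices and the loader had to guess the CTA-level tile size; they now pass the starting element coordinates of the subtensor, i.e. the rep index times the CTA-level tile size.

#include <cstdio>

// Stand-in for DotOpMmaMemLoader::memLoad; only the calling convention matters here.
struct FakeLoader {
  void memLoad(int startRow, int startCol) const {
    std::printf("load subtensor starting at (%d, %d)\n", startRow, startCol);
  }
};

int main() {
  FakeLoader loader;
  const int mmaSizeM = 64, mmaSizeK = 16; // hypothetical CTA-level tile sizes
  for (int m = 0; m < 2; ++m)
    for (int k = 0; k < 2; ++k)
      // Before this commit callers wrote loader.memLoad(m, k) and the loader
      // scaled by an internally guessed instruction shape; now the caller
      // does the scaling itself and passes starting coordinates.
      loader.memLoad(m * mmaSizeM, k * mmaSizeK);
  return 0;
}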

File tree

5 files changed: +73 -88 lines changed


test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 18 additions & 0 deletions
@@ -74,6 +74,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 
 // -----
 
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [16, 2], instrShape = [16, 256, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @warp_group_dot_bf16_32_warps
+  tt.func @warp_group_dot_bf16_32_warps(
+      %a: !ttg.memdesc<256x128xbf16, #shared, #smem>,
+      %b: !ttg.memdesc<128x512xbf16, #shared, #smem>,
+      %acc: tensor<256x512xf32, #mma>) {
+    %res = ttng.warp_group_dot %a, %b, %acc {inputPrecision = 0 : i32, isAsync = true} :
+      !ttg.memdesc<256x128xbf16, #shared, #smem> * !ttg.memdesc<128x512xbf16, #shared, #smem> -> tensor<256x512xf32, #mma>
+    // CHECK: nvgpu.wgmma {{.*}} k = 16 : i32, layoutA = 1 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32}
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 15 additions & 31 deletions
@@ -41,6 +41,8 @@ struct MemDescOperand {
 class DotOpMmaMemLoader {
 public:
   virtual ~DotOpMmaMemLoader() = default;
+  // Given the starting coordinates of the logical tensor (i.e. reps *
+  // ctaTileSize), return the associated memory descriptor for SMEM / TMEM.
   virtual MemDescOperand memLoad(int a, int b,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) const = 0;
@@ -50,10 +52,8 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
 public:
   DotOpMmaSmemLoader() = default;
 
-  DotOpMmaSmemLoader(MMASMEMDescriptor desc, Value baseb128, LinearLayout llInv,
-                     ArrayRef<unsigned> instrShape)
-      : desc(desc), baseb128(baseb128), ll(std::move(llInv)),
-        instrShape(instrShape) {}
+  DotOpMmaSmemLoader(MMASMEMDescriptor desc, Value baseb128, LinearLayout llInv)
+      : desc(desc), baseb128(baseb128), ll(std::move(llInv)) {}
 
   static DotOpMmaSmemLoader
   build(Location loc, RewriterBase &rewriter, gpu::MemDescType memTy,
@@ -136,15 +136,6 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
                               {{kWarp, warpId}})[0]
                      .second;
       baseSrcb128 = b.add(baseSrcb128, warpStrideb128);
-      // Increase the instruction shape to describe the size at a block level
-      // as the input just describes it at a warp level
-      int logwgAlongMN = 0;
-      for (int i = 0; i < warpGroupToOffsetb128.getInDimSizeLog2(kWarp); i++) {
-        if (warpGroupToOffsetb128.getBasis(kWarp, i, kOffset) != 0) {
-          logwgAlongMN++;
-        }
-      }
-      instrShape[MNdim] *= (1 << logwgAlongMN);
     }
 
     for (auto [dim, instrSize] : llvm::zip(ll.getInDimNames(), instrShape)) {
@@ -155,22 +146,18 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     auto desc = getDescriptor(ll, instrShape, bitwidth, MNdim, mmaVersion);
 
     Value baseb128 = b.zext(i64_ty, b.and_(baseSrcb128, b.i32_val(0x3FFF)));
-    return {desc, baseb128, ll, instrShape};
+    return {desc, baseb128, ll};
   }
 
   Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
                  Location loc) const {
     auto *ctx = loc.getContext();
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto dims = to_vector(ll.getInDimNames());
-    assert((a + 1) * instrShape[0] <= ll.getInDimSize(dims[0]));
-    assert((b + 1) * instrShape[1] <= ll.getInDimSize(dims[1]));
     assert(to_vector(ll.getOutDimNames()) ==
            llvm::to_vector(
                ArrayRef<StringAttr>{str_attr("offset"), str_attr("block")}));
-    int32_t totalOffElems = ll.apply({{dims[0], a * instrShape[0]},
-                                      {dims[1], b * instrShape[1]}})[0]
-                                .second;
+    int32_t totalOffElems = ll.apply({{dims[0], a}, {dims[1], b}})[0].second;
     int32_t smemByteOffsetb8 = totalOffElems * desc.bitwidth / 8;
     auto currDesc = desc.descriptor;
     // Take the next 0/1/2/3 bits after the 128b tile
@@ -194,7 +181,6 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
   MMASMEMDescriptor desc;
   Value baseb128;
   LinearLayout ll;
-  SmallVector<unsigned> instrShape;
 
   static MMASMEMDescriptor getDescriptor(const LinearLayout &ll,
                                          ArrayRef<unsigned> instrShape,
@@ -337,9 +323,9 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
 class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
 public:
   DotOpMmaV5TmemLoader() {}
-  DotOpMmaV5TmemLoader(Value tensor, Value base,
-                       SmallVector<unsigned int> instrShape, bool interleaved,
-                       bool trans);
+  static DotOpMmaV5TmemLoader build(Location loc, RewriterBase &rewriter,
+                                    gpu::MemDescType memTy, Value tmemBase);
+
   MemDescOperand tmemLoad(int a, int b, ConversionPatternRewriter &rewriter,
                           Location loc) const;
 
@@ -349,14 +335,12 @@ class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
   }
 
 private:
-  Value base;
-  bool trans;
-  bool interleaved;
-  bool unpacked;
-  SmallVector<unsigned int> instrShape;
-  int numElementsPer32b;
-  int numRepM;
-  int numSlicePerBlockN;
+  DotOpMmaV5TmemLoader(LinearLayout ll, Value address, int bitwidth)
+      : ll(std::move(ll)), address(address), bitwidth(bitwidth) {}
+
+  LinearLayout ll;
+  Value address;
+  int bitwidth;
 };
 
 static Value getOffsetedBase(Value v, gpu::MemDescType memDescTy,
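A minimal model of the resulting smemLoad offset math (the layout lambda below is hypothetical; the real code applies a LinearLayout and also updates the shared-memory descriptor): the incoming (a, b) pair is already an element coordinate, so it goes straight through the layout with no instrShape scaling, and the byte offset is the element offset scaled by the element width.

#include <cstdint>
#include <cstdio>
#include <functional>

// Models the new DotOpMmaSmemLoader::smemLoad offset computation: no instrShape
// scaling and no per-rep bounds asserts, just layout(a, b) -> element offset.
int32_t smemByteOffset(const std::function<int32_t(int, int)> &layoutElems,
                       int a, int b, int bitwidth) {
  int32_t totalOffElems = layoutElems(a, b);
  return totalOffElems * bitwidth / 8;
}

int main() {
  // Hypothetical row-major layout over a 128x64-element shared-memory tile.
  auto layout = [](int row, int col) { return row * 64 + col; };
  std::printf("%d\n", smemByteOffset(layout, /*a=*/64, /*b=*/16, /*bitwidth=*/16));
  return 0;
}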

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 25 additions & 42 deletions
@@ -18,44 +18,25 @@ using ::mlir::triton::gpu::SharedLinearEncodingAttr;
 // DotOpMmaV5TmemLoader
 //===----------------------------------------------------------------------===//
 
-mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader(
-    Value tensor, Value base, SmallVector<unsigned int> instrShape,
-    bool interleaved, bool trans)
-    : base(base), instrShape(instrShape), interleaved(interleaved),
-      trans(trans) {
-  auto ty = cast<MemDescType>(tensor.getType());
-  auto tmemEncoding = cast<ttng::TensorMemoryEncodingAttr>(ty.getEncoding());
-  int elTyWidth = ty.getElementTypeBitWidth();
-  unpacked = tmemEncoding.getColStride() != 1;
-  // When using TMEM to store operands mma operands the TMEM block size may be
-  // smaller than mma k block. Therefore we need to adjust the offset
-  // calculation.
-  numSlicePerBlockN = tmemEncoding.getBlockN() / instrShape[1];
-  numElementsPer32b = 32 / (elTyWidth * tmemEncoding.getColStride());
-  auto shapePerCTA = triton::gpu::getShapePerCTA(ty);
-  numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0]);
+DotOpMmaV5TmemLoader mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::build(
+    Location loc, RewriterBase &rewriter, gpu::MemDescType memTy,
+    Value tmemBase) {
+  auto ctx = loc.getContext();
+  auto ll = toLinearLayout(memTy);
+  auto layout = cast<ttng::TensorMemoryEncodingAttr>(memTy.getEncoding());
+  auto bitwidth = memTy.getElementTypeBitWidth();
+  auto tb = TritonLLVMOpBuilder(loc, rewriter);
+  Value address = tb.ptrtoint(i32_ty, tmemBase);
+  return DotOpMmaV5TmemLoader(ll.pseudoinvert(), address, bitwidth);
 }
 
 MemDescOperand mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
     int a, int b, ConversionPatternRewriter &rewriter, Location loc) const {
-  int numRows = 64;
-  if (interleaved || instrShape[0] >= 128)
-    numRows = 128;
-  int numColPerBlock =
-      ((instrShape[0] * numSlicePerBlockN * instrShape[1]) / numRows) /
-      numElementsPer32b;
-  int blockId = a + (b / numSlicePerBlockN) * numRepM;
-  int offset;
-  if (!interleaved) {
-    offset = numColPerBlock * blockId;
-  } else {
-    int blockIdIsOdd = blockId & 1;
-    int blockIdPrevEven = blockId - blockIdIsOdd;
-    offset = numColPerBlock * blockIdPrevEven + ((16 * blockIdIsOdd) << 16);
-  }
-  offset += (b % numSlicePerBlockN) * (instrShape[1] / numElementsPer32b);
-  auto tb = TritonLLVMOpBuilder(loc, rewriter);
-  Value address = tb.ptrtoint(i32_ty, base);
+  auto dims = to_vector(ll.getInDimNames());
+  auto rowCol = ll.apply({{dims[0], a}, {dims[1], b}});
+  int row = rowCol[0].second;
+  int col = rowCol[1].second * bitwidth / 32;
+  int offset = col | (row << 16);
   return {address, offset};
 }
 
@@ -445,8 +426,8 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
   std::unique_ptr<DotOpMmaMemLoader> aLoader;
   bool transA = false;
   if (aInTmem) {
-    aLoader = std::make_unique<DotOpMmaV5TmemLoader>(a, baseA, aOperandShape,
-                                                     interleaved, transA);
+    aLoader = std::make_unique<DotOpMmaV5TmemLoader>(
+        DotOpMmaV5TmemLoader::build(loc, rewriter, aTensorTy, baseA));
   } else {
     auto isFp4a = op.numBitsPerElementA == 4;
     aLoader = std::make_unique<DotOpMmaSmemLoader>(DotOpMmaSmemLoader::build(
@@ -479,8 +460,9 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
       Value useInitAcc = useDFlag;
       MemDescOperand accAddress = op.getAccAddress(rewriter, loc, m, n, desc);
       for (int k = 0; k < numRepK; k++) {
-        MemDescOperand a = aLoader->memLoad(m, k, rewriter, loc);
-        Value b = bLoader.smemLoad(k, n, rewriter, loc);
+        MemDescOperand a =
+            aLoader->memLoad(m * mmaSizeM, k * mmaSizeK, rewriter, loc);
+        Value b = bLoader.smemLoad(k * mmaSizeK, n * mmaSizeN, rewriter, loc);
         op.createMMAInst(rewriter, loc, accAddress, a, b, elect, useInitAcc,
                          desc, m, n, k);
         useInitAcc = tb.i1_val(1);
@@ -503,6 +485,7 @@ void convertDot(const LLVMTypeConverter &typeConverter,
   MemDescType aTensorTy = op.getA().getType();
   MemDescType bTensorTy = op.getB().getType();
   MemDescType dTensorTy = op.getD().getType();
+  auto dLayout = cast<ttng::TensorMemoryEncodingAttr>(dTensorTy.getEncoding());
   bool twoCTAs = op.getTwoCtas();
 
   DotConversion dot;
@@ -518,12 +501,12 @@
   dot.numBitsPerElementA = aTensorTy.getElementTypeBitWidth();
   dot.numBitsPerElementB = bTensorTy.getElementTypeBitWidth();
 
+  DotOpMmaV5TmemLoader dLoader =
+      DotOpMmaV5TmemLoader::build(loc, rewriter, dTensorTy, adaptor.getD());
   dot.getAccAddress = [&](ConversionPatternRewriter &rewriter, Location loc,
                           int m, int n, const DotConversion::InstDesc &desc) {
-    DotOpMmaV5TmemLoader dLoader = DotOpMmaV5TmemLoader(
-        op.getD(), adaptor.getD(), {desc.mmaSizeM, desc.mmaSizeN},
-        desc.interleaved, /*trans=*/false);
-    return dLoader.tmemLoad(m, n, rewriter, loc);
+    return dLoader.tmemLoad(m * dLayout.getBlockM(), n * dLayout.getBlockN(),
+                            rewriter, loc);
   };
 
   dot.createMMAInst = [&](ConversionPatternRewriter &rewriter, Location loc,
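The new tmemLoad is small enough to model in isolation. A sketch of the offset packing (the row/column values below are hypothetical; in the real code they come from applying the pseudo-inverted linear layout to the starting coordinates): the column is converted from elements to 32-bit words and the row occupies the upper 16 bits of the TMEM offset.

#include <cstdint>
#include <cstdio>

// Mirrors the arithmetic at the end of the new DotOpMmaV5TmemLoader::tmemLoad.
int32_t tmemOffset(int row, int colElems, int bitwidth) {
  int col = colElems * bitwidth / 32; // element column -> 32-bit word column
  return col | (row << 16);           // row goes in the upper half
}

int main() {
  // E.g. a bf16 operand whose layout maps the coordinates to row 64, element column 128.
  std::printf("0x%x\n", tmemOffset(/*row=*/64, /*colElems=*/128, /*bitwidth=*/16));
  return 0;
}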

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 14 additions & 12 deletions
@@ -204,19 +204,21 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   auto baseB = getOffsetedBase(loadedB, cast<MemDescType>(bTensorTy),
                                typeConverter, rewriter, loc);
   auto dShapePerCTA = getShapePerCTA(dTensorTy);
-  auto instrShape = mmaEncoding.getInstrShape();
-  auto accSize = 2 * (instrShape[1] / 4);
-  unsigned M = 4 * instrShape[0];
-  unsigned N = instrShape[1];
-  unsigned K = instrShape[2];
-  bool zeroAcc = isZeroConst(c);
   auto instrMNK = mmaEncoding.getInstrShape();
+  auto accSize = 2 * (instrMNK[1] / 4);
+  unsigned M = 4 * instrMNK[0];
+  unsigned N = instrMNK[1];
+  unsigned K = instrMNK[2];
+  bool zeroAcc = isZeroConst(c);
   auto warpSize = mmaEncoding.getWarpsPerCTA();
   auto shapePerCTATile = SmallVector<unsigned>{instrMNK[0] * warpSize[0],
                                                instrMNK[1] * warpSize[1]};
-  int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
-  int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
-  int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
+  unsigned mmaSizeM = shapePerCTATile[0];
+  unsigned mmaSizeN = shapePerCTATile[1];
+  unsigned mmaSizeK = instrMNK[2];
+  int numRepM = ceil<unsigned>(dShapePerCTA[0], mmaSizeM);
+  int numRepN = ceil<unsigned>(dShapePerCTA[1], mmaSizeN);
+  int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], mmaSizeK);
   DotOpMmaSmemLoader aLoader;
   SmallVector<Value> structA;
   auto warpGroups = {warpSize[0] / 4, warpSize[1]};
@@ -270,14 +272,14 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
     for (int k = 0; k < numRepK; ++k) {
       Value a;
       if (aInShared) {
-        a = aLoader.smemLoad(m, k, rewriter, loc);
+        a = aLoader.smemLoad(m * mmaSizeM, k * mmaSizeK, rewriter, loc);
       } else {
         auto aDotOpEnc =
             cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
         assert(aDotOpEnc.getKWidth() ==
                32 / aTensorTy.getElementTypeBitWidth());
 
-        unsigned regASize = (instrShape[0] * instrShape[2]) / 32;
+        unsigned regASize = (instrMNK[0] * instrMNK[2]) / 32;
         llvm::SmallVector<Value> regA =
             loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize,
                     regASize, startSequence);
@@ -286,7 +288,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
             SmallVector<Type>(regA.size(), regA[0].getType()));
         a = packLLElements(loc, typeConverter, regA, rewriter, regATy);
       }
-      auto b = bLoader.smemLoad(k, n, rewriter, loc);
+      auto b = bLoader.smemLoad(k * mmaSizeK, n * mmaSizeN, rewriter, loc);
       numLowPrecisionAcc += K;
       // If using native accumulation would cause use to do more low precion
       // accumulation than allowed do a separate allocation.
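Putting the WGMMA side together, a rough model of how the rep loops now produce element coordinates (the shapes are taken from the new test above; everything else is a stand-in for the real lowering):

#include <cstdio>

unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
  // From the new test: instrShape = [16, 256, 16], warpsPerCTA = [16, 2],
  // A is 256x128 and D is 256x512.
  unsigned instrMNK[3] = {16, 256, 16};
  unsigned warpsPerCTA[2] = {16, 2};
  unsigned mmaSizeM = instrMNK[0] * warpsPerCTA[0]; // 256
  unsigned mmaSizeN = instrMNK[1] * warpsPerCTA[1]; // 512
  unsigned mmaSizeK = instrMNK[2];                  // 16
  unsigned numRepM = ceilDiv(256, mmaSizeM);        // 1
  unsigned numRepN = ceilDiv(512, mmaSizeN);        // 1
  unsigned numRepK = ceilDiv(128, mmaSizeK);        // 8
  for (unsigned m = 0; m < numRepM; ++m)
    for (unsigned n = 0; n < numRepN; ++n)
      for (unsigned k = 0; k < numRepK; ++k)
        // These are the starting coordinates handed to smemLoad, replacing
        // the old rep indices (m, k) and (k, n).
        std::printf("a at (%u, %u), b at (%u, %u)\n", m * mmaSizeM,
                    k * mmaSizeK, k * mmaSizeK, n * mmaSizeN);
  return 0;
}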

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 1 addition & 3 deletions
@@ -604,9 +604,7 @@ static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc,
   }
 
   for (int col = 0; col < cvt.getInDimSize(kCol); col += instrShape[1]) {
-    // smemLoad takes the colRep. It'd be nice to change this but we would need
-    // to change the wgmma and mmav5 lowering
-    auto desc = loader.smemLoad(0, col / instrShape[1], rewriter, loc);
+    auto desc = loader.smemLoad(0, col, rewriter, loc);
     auto tmemAddr =
         b.or_(b.ptrtoint(i32_ty, baseDst), b.i32_val(col * bitwidth / 32),
               /*disjoint=*/true);
