Skip to content

Commit 4595f3a

Browse files
authored
[Blackwell] Optimize MMAv5 lowering to reduce register usage (#6817)
The MMAv5 instruction supports constant offsets encoded directly in the instruction for TMEM memory descriptors, such as for the `d` operand or if `a` is in TMEM. Using constant offsets reduces register pressure because each new offset doesn't require a register. It also helps a lot when there are pipelined MMAv5 instructions or multiple in the same loop because LLVM will CSE and hoist all the offsets out of the loop and PTXAS will keep them live for the whole loop instead of rematerializing them. This means each `ttng.tc_gen5_mma` can end up using up to 15-20 registers in the loop because of all the offsets.
1 parent 2c57e20 commit 4595f3a

File tree

3 files changed

+70
-82
lines changed

3 files changed

+70
-82
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
4141
#tmem = #ttng.tensor_memory_encoding<blockM = 64, blockN = 64, unpacked = true>
4242
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
4343
// CHECK-LABEL: @tc_gen5_mma_multi_m_n
44-
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
45-
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
46-
// CHECK-DAG: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32
47-
// CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
48-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]]
49-
// CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C64]] : i32
50-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]]
44+
// CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
45+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
46+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 64 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
5147
// 1048576 = row << 16 + col = 16 << 16 + 0
52-
// CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32
53-
// CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32
54-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]]
48+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048576 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
5549
// 1048640 = row << 16 + col = 16 << 16 + 64
56-
// CHECK: %[[C1048640:.+]] = llvm.mlir.constant(1048640 : i32) : i32
57-
// CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048640]] : i32
58-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]]
50+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048640 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
5951

6052
tt.func @tc_gen5_mma_multi_m_n(%a: !ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>,
6153
%b: !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>,
@@ -82,21 +74,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
8274
#tmem = #ttng.tensor_memory_encoding<blockM = 64, blockN = 32, unpacked = true, CTASplitN = 2>
8375
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
8476
// CHECK-LABEL: @tc_gen5_mma_multi_ctas
85-
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
86-
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
87-
// CHECK-DAG: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32
88-
// CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
89-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]]
90-
// CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C32]] : i32
91-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]]
77+
// CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
78+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
79+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 32 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
9280
// 1048576 = row << 16 + col = 16 << 16 + 0
93-
// CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32
94-
// CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32
95-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]]
81+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048576 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
9682
// 1048608 = row << 16 + col = 16 << 16 + 32
97-
// CHECK: %[[C1048608:.+]] = llvm.mlir.constant(1048608 : i32) : i32
98-
// CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048608]] : i32
99-
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]]
83+
// CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048608 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
10084

10185
tt.func @tc_gen5_mma_multi_ctas(%a: !ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>,
10286
%b: !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>,
@@ -203,12 +187,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
203187
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
204188
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
205189
// CHECK: llvm.cond_br %[[P1]]
206-
// CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
207190
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144708608 : i32) : i32
208-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %arg5
191+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %arg5
209192
// CHECK: %[[TRUE:.+]] = llvm.mlir.constant(true) : i1
210193
// CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681579536 : i32) : i32
211-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]]
194+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]]
212195
tt.func @tc_gen5_mma_block_scale(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
213196
%b: !ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>,
214197
%c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -320,12 +303,10 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
320303
#tmem_scales = #ttng.tensor_memory_scales_encoding<>
321304
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
322305
// CHECK-LABEL: @tc_gen5_mma_block_scale_nvfp4
323-
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
324-
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
325-
// CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
306+
// CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
326307
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(138413184 : i32) : i32
327-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]]
328-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]]
308+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC0]]
309+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC0]]
329310
tt.func @tc_gen5_mma_block_scale_nvfp4(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
330311
%b: !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>,
331312
%c: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -356,12 +337,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
356337
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
357338
// CHECK-LABEL: @tc_gen5_mma_block_scale_mxfp4
358339
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
359-
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
360-
// CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
361340
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(146801792 : i32) : i32
362-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]]
341+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC0]]
363342
// CHECK: %[[DESC1:.+]] = llvm.mlir.constant(1220543648 : i32) : i32
364-
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC1]]
343+
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC1]]
365344
tt.func @tc_gen5_mma_block_scale_mxfp4(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
366345
%b: !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>,
367346
%c: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@ union SMEMDescriptor {
2323
};
2424
};
2525

26+
struct MemDescOperand {
27+
Value base;
28+
std::optional<int> offset;
29+
};
30+
2631
// Abstract class to calculate the address of a shared or tensor memory slice.
2732
class DotOpMmaMemLoader {
2833
public:
2934
virtual ~DotOpMmaMemLoader() = default;
30-
virtual Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
31-
Location loc) const = 0;
35+
virtual MemDescOperand memLoad(int a, int b,
36+
ConversionPatternRewriter &rewriter,
37+
Location loc) const = 0;
3238
};
3339

3440
// Helper class to load shared memory slices following MMAv3 layout.
@@ -46,9 +52,9 @@ class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader {
4652
Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
4753
Location loc) const;
4854

49-
Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
50-
Location loc) const override {
51-
return smemLoad(a, b, rewriter, loc);
55+
MemDescOperand memLoad(int a, int b, ConversionPatternRewriter &rewriter,
56+
Location loc) const override {
57+
return {smemLoad(a, b, rewriter, loc), std::nullopt};
5258
}
5359

5460
private:
@@ -73,11 +79,11 @@ class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
7379
DotOpMmaV5TmemLoader(Value tensor, Value base,
7480
SmallVector<unsigned int> instrShape, bool interleaved,
7581
bool trans);
76-
Value tmemLoad(int a, int b, ConversionPatternRewriter &rewriter,
77-
Location loc) const;
82+
MemDescOperand tmemLoad(int a, int b, ConversionPatternRewriter &rewriter,
83+
Location loc) const;
7884

79-
Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
80-
Location loc) const override {
85+
MemDescOperand memLoad(int a, int b, ConversionPatternRewriter &rewriter,
86+
Location loc) const override {
8187
return tmemLoad(a, b, rewriter, loc);
8288
}
8389

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,26 @@ mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader(
3232
numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0]);
3333
}
3434

35-
Value mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
35+
MemDescOperand mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
3636
int a, int b, ConversionPatternRewriter &rewriter, Location loc) const {
37-
auto tb = TritonLLVMOpBuilder(loc, rewriter);
3837
int numRows = 64;
3938
if (interleaved || instrShape[0] >= 128)
4039
numRows = 128;
4140
int numColPerBlock =
4241
((instrShape[0] * instrShape[1]) / numRows) / numElementsPer32b;
43-
Value address = base;
4442
int blockId = a + b * numRepM;
45-
address = tb.ptrtoint(i32_ty, address);
43+
int offset;
4644
if (!interleaved) {
47-
address = tb.add(address, tb.i32_val(numColPerBlock * blockId));
45+
offset = numColPerBlock * blockId;
4846
} else {
4947
int blockIdIsOdd = blockId & 1;
5048
int blockIdPrevEven = blockId - blockIdIsOdd;
51-
Value offset = tb.i32_val(numColPerBlock * blockIdPrevEven +
52-
((16 * blockIdIsOdd) << 16));
53-
address = tb.add(address, offset);
49+
offset = numColPerBlock * blockIdPrevEven + ((16 * blockIdIsOdd) << 16);
5450
}
55-
return address;
51+
52+
auto tb = TritonLLVMOpBuilder(loc, rewriter);
53+
Value address = tb.ptrtoint(i32_ty, base);
54+
return {address, offset};
5655
}
5756

5857
//===----------------------------------------------------------------------===//
@@ -229,9 +228,9 @@ static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter,
229228
//===----------------------------------------------------------------------===//
230229

231230
static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
232-
ttng::TCGen5MMAOp op, Value a, Value b, Value d,
233-
Value pred, Value instDescriptor, Value useInitAcc,
234-
bool aInTMem, bool twoCTAs) {
231+
ttng::TCGen5MMAOp op, MemDescOperand a, Value b,
232+
MemDescOperand d, Value pred, Value instDescriptor,
233+
Value useInitAcc, bool aInTMem, bool twoCTAs) {
235234
PTXBuilder ptxBuilder;
236235
std::string opcode =
237236
"tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::";
@@ -244,9 +243,10 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
244243
opcode += "f8f6f4";
245244
else
246245
assert(0 && "Unsupported type.");
247-
auto *accOp = ptxBuilder.newAddrOperand(d, "r");
248-
auto *aOp = aInTMem ? ptxBuilder.newAddrOperand(a, "r")
249-
: ptxBuilder.newOperand(a, "l");
246+
auto *accOp = ptxBuilder.newAddrOperand(d.base, "r", *d.offset);
247+
assert(a.offset.has_value() == aInTMem);
248+
auto *aOp = aInTMem ? ptxBuilder.newAddrOperand(a.base, "r", *a.offset)
249+
: ptxBuilder.newOperand(a.base, "l");
250250
auto *bOp = ptxBuilder.newOperand(b, "l");
251251
auto *instDescOp = ptxBuilder.newOperand(instDescriptor, "r");
252252
auto *useInitAccOp = ptxBuilder.newOperand(useInitAcc, "b");
@@ -257,10 +257,10 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
257257

258258
static void createScaledGen5MMA(ConversionPatternRewriter &rewriter,
259259
Location loc, ttng::TCGen5MMAScaledOp op,
260-
Value a, Value b, Value d, Value scaleA,
261-
Value scaleB, Value pred, Value instDescriptor,
262-
Value useInitAcc, bool aInTmem,
263-
mxfpKind mxfpInstKind) {
260+
MemDescOperand a, Value b, MemDescOperand d,
261+
Value scaleA, Value scaleB, Value pred,
262+
Value instDescriptor, Value useInitAcc,
263+
bool aInTmem, mxfpKind mxfpInstKind) {
264264
PTXBuilder ptxBuilder;
265265
std::string opcode;
266266
if (mxfpInstKind == mxfpKind::mxf8f6f4) {
@@ -274,9 +274,10 @@ static void createScaledGen5MMA(ConversionPatternRewriter &rewriter,
274274
} else {
275275
assert(0 && "Unsupported mxfp kind.");
276276
}
277-
auto *accOp = ptxBuilder.newAddrOperand(d, "r");
278-
auto *aOp = aInTmem ? ptxBuilder.newAddrOperand(a, "r")
279-
: ptxBuilder.newOperand(a, "l");
277+
auto *accOp = ptxBuilder.newAddrOperand(d.base, "r", *d.offset);
278+
assert(aInTmem == a.offset.has_value());
279+
auto *aOp = aInTmem ? ptxBuilder.newAddrOperand(a.base, "r", *a.offset)
280+
: ptxBuilder.newOperand(a.base, "l");
280281
auto *bOp = ptxBuilder.newOperand(b, "l");
281282
auto *instDescOp = ptxBuilder.newOperand(instDescriptor, "r");
282283
auto *scaleAOp = ptxBuilder.newAddrOperand(scaleA, "r");
@@ -335,11 +336,11 @@ struct DotConversion {
335336
bool aInTmem;
336337
};
337338

338-
using GetAccAddressFn = std::function<Value(
339+
using GetAccAddressFn = std::function<MemDescOperand(
339340
ConversionPatternRewriter &, Location, int, int, const InstDesc &)>;
340-
using CreateMMAInstFn =
341-
std::function<void(ConversionPatternRewriter &, Location, Value, Value,
342-
Value, Value, Value, const InstDesc &, int, int, int)>;
341+
using CreateMMAInstFn = std::function<void(
342+
ConversionPatternRewriter &, Location, MemDescOperand, MemDescOperand,
343+
Value, Value, Value, const InstDesc &, int, int, int)>;
343344

344345
struct {
345346
unsigned M;
@@ -456,9 +457,9 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
456457
for (int m = 0; m < numRepM; m++) {
457458
for (int n = 0; n < numRepN; n++) {
458459
Value useInitAcc = useDFlag;
459-
Value accAddress = op.getAccAddress(rewriter, loc, m, n, desc);
460+
MemDescOperand accAddress = op.getAccAddress(rewriter, loc, m, n, desc);
460461
for (int k = 0; k < numRepK; k++) {
461-
Value a = aLoader->memLoad(m, k, rewriter, loc);
462+
MemDescOperand a = aLoader->memLoad(m, k, rewriter, loc);
462463
Value b = bLoader.smemLoad(n, k, rewriter, loc);
463464
op.createMMAInst(rewriter, loc, accAddress, a, b, elect, useInitAcc,
464465
desc, m, n, k);
@@ -506,9 +507,10 @@ void convertDot(const LLVMTypeConverter &typeConverter,
506507
};
507508

508509
dot.createMMAInst = [&](ConversionPatternRewriter &rewriter, Location loc,
509-
Value accAddress, Value a, Value b, Value pred,
510-
Value useInitAcc, const DotConversion::InstDesc &desc,
511-
int m, int n, int k) {
510+
MemDescOperand accAddress, MemDescOperand a, Value b,
511+
Value pred, Value useInitAcc,
512+
const DotConversion::InstDesc &desc, int m, int n,
513+
int k) {
512514
Value instDescriptor = createInstDescriptor(
513515
rewriter, op, twoCTAs ? desc.mmaSizeM * 2 : desc.mmaSizeM,
514516
desc.mmaSizeN, desc.transA, desc.transB);
@@ -598,13 +600,14 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
598600
dTensorTy.getElementTypeBitWidth(),
599601
numRows * colSizeInBits);
600602
int blockId = m + n * desc.repShape.numRepM;
601-
return tb.add(baseD, tb.i32_val(numColPerBlock * blockId));
603+
return MemDescOperand{baseD, numColPerBlock * blockId};
602604
};
603605

604606
dot.createMMAInst = [&](ConversionPatternRewriter &rewriter, Location loc,
605-
Value accAddress, Value a, Value b, Value pred,
606-
Value useInitAcc, const DotConversion::InstDesc &desc,
607-
int m, int n, int k) {
607+
MemDescOperand accAddress, MemDescOperand a, Value b,
608+
Value pred, Value useInitAcc,
609+
const DotConversion::InstDesc &desc, int m, int n,
610+
int k) {
608611
auto [numRepM, numRepN, numRepK] = desc.repShape;
609612
int scaleFactorColsPerSet = getScaleFactorColsPerSet(mxfpInstKind);
610613
int numColPerScaleBlockA = ceil<int>(

0 commit comments

Comments
 (0)