Skip to content

Commit dfcdc27

Browse files
authored
[Backend] Refactor MMAv5 lowering and put mma_scaled in an if (triton-lang#6478)
This PR refactors the lowering of MMAv5 to share more of the code between tc_gen5_mma and tc_gen5_mma_scaled, while also applying the same optimization that tc_gen5_mma has that places the `tcgen05.mma` instructions in an if block. This was previously checked to improve performance.
1 parent aac457e commit dfcdc27

File tree

4 files changed

+302
-259
lines changed

4 files changed

+302
-259
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
197197
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
198198
// CHECK-LABEL: @tc_gen5_mma_block_scale
199199
// CHECK-SAME: (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[USE_ACC:.+]]: i1, %{{.*}}: i1, %{{.*}})
200-
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
201-
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
202-
// CHECK-DAG: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32
200+
// CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
201+
// CHECK: %[[WID:.+]] = nvgpu.warp_id
202+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
203+
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
204+
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
205+
// CHECK: llvm.cond_br %[[P1]]
203206
// CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
204207
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144708608 : i32) : i32
205208
// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %[[USE_ACC]]

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class DotOpMmaMemLoader {
2828
public:
2929
virtual ~DotOpMmaMemLoader() = default;
3030
virtual Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
31-
Location loc) = 0;
31+
Location loc) const = 0;
3232
};
3333

3434
// Helper class to load shared memory slices following MMAv3 layout.
@@ -44,10 +44,10 @@ class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader {
4444
// Return a descriptor pointing to the shared memory slice at coordinates (a,
4545
// b)
4646
Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
47-
Location loc);
47+
Location loc) const;
4848

4949
Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
50-
Location loc) override {
50+
Location loc) const override {
5151
return smemLoad(a, b, rewriter, loc);
5252
}
5353

@@ -74,10 +74,10 @@ class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
7474
SmallVector<unsigned int> instrShape, bool interleaved,
7575
bool trans);
7676
Value tmemLoad(int a, int b, ConversionPatternRewriter &rewriter,
77-
Location loc);
77+
Location loc) const;
7878

7979
Value memLoad(int a, int b, ConversionPatternRewriter &rewriter,
80-
Location loc) override {
80+
Location loc) const override {
8181
return tmemLoad(a, b, rewriter, loc);
8282
}
8383

0 commit comments

Comments
 (0)