
Commit cfe3dd0

[Backend][NFC] Switch some inline PTX to NVVM ops/intrinsics (#7725)
Parent: 1a1262f

File tree: 19 files changed, +283 -330 lines
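At a glance, the change replaces textual PTX (inline assembly and raw intrinsic calls) with the corresponding NVVM dialect operations. A minimal before/after sketch of the program-ID lowering, based on the updated FileCheck patterns in the tests below:

// Before: the program ID was materialized via a raw intrinsic call.
%pid_old = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32

// After: the dedicated NVVM dialect op is emitted instead.
%pid_new = nvvm.read.ptx.sreg.ctaid.x : i32

Since the commit is tagged NFC, the emitted PTX should be unchanged; only the IR-level representation moves from opaque inline asm/intrinsic calls to first-class NVVM ops that downstream passes can reason about.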

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton {
+ enum class ProgramIDDim : uint32_t;

class TargetInfoBase {
public:

@@ -48,7 +49,7 @@ class TargetInfoBase {
                         Value i) const = 0;

  virtual Value programId(RewriterBase &rewriter, Location loc,
-                         ModuleOp moduleOp, int axis) const = 0;
+                         ModuleOp moduleOp, ProgramIDDim axis) const = 0;

  virtual bool warpReduce(RewriterBase &rewriter, Location loc,
                          SmallVector<Value> &acc, triton::ReduceOp op,

lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 4 additions & 5 deletions
@@ -26,11 +26,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
                    ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

-   auto getPid = [&](int axis) {
-     return targetInfo.programId(rewriter, loc,
-                                 op->getParentOfType<ModuleOp>(), axis);
-   };
-   std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
+   std::array<Value, 3> pid;
+   auto module = op->getParentOfType<ModuleOp>();
+   for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z})
+     pid[(int)axis] = targetInfo.programId(rewriter, loc, module, axis);

    // Simple printf of a string without any tensors.
    if (op.getNumOperands() == 0) {

lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ struct GetProgramIdOpConversion
  LogicalResult
  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-   Value programId = targetInfo.programId(rewriter, op->getLoc(),
-                                          op->getParentOfType<ModuleOp>(),
-                                          op.getAxisAsInt());
+   Value programId = targetInfo.programId(
+       rewriter, op->getLoc(), op->getParentOfType<ModuleOp>(), op.getAxis());
    rewriter.replaceOp(op, programId);
    return success();
  }

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 2 deletions
@@ -508,7 +508,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
  tt.func @basic_program_id() {
-   // CHECK: llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
+   // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
    %0 = tt.get_program_id x : i32
    tt.return
  }

@@ -2089,7 +2089,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_st
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
  tt.func public @tensor_memory_st(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
    %0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 10 additions & 10 deletions
@@ -107,9 +107,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_ld
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x128.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127}, [$128 + 0];", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
- // CHECK: tcgen05.wait::ld.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <load>
  tt.func public @tensor_memory_ld(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>

@@ -158,10 +158,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
// CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
// CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
- // CHECK: tcgen05.wait::ld.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <load>
  tt.func public @tensor_memory_ld_m64(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>

@@ -179,9 +179,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_unpack_f16
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63}, [$64 + 0];", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
- // CHECK: tcgen05.wait::ld.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <load>
  tt.func public @tensor_memory_unpack_f16() {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked1>
    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>

@@ -388,10 +388,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_ld_128x256
// CHECK-COUNT-4: tcgen05.st.sync.aligned.32x32b.x64.b32
// CHECK-NOT: tcgen05.st
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
// CHECK-COUNT-4: tcgen05.ld.sync.aligned.32x32b.x64.b32
// CHECK-NOT: tcgen05.ld
- // CHECK: tcgen05.wait::ld.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <load>
  tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>

@@ -408,9 +408,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @tensor_memory_ld_128x256_8_warps
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
- // CHECK: tcgen05.wait::st.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
- // CHECK: tcgen05.wait::ld.sync.aligned
+ // CHECK: nvvm.tcgen05.wait <load>
  tt.func public @tensor_memory_ld_128x256_8_warps(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
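For the tcgen05 paths above, the wait instructions follow the same pattern: the inline-PTX strings are replaced by the NVVM tcgen05 wait op, parameterized by the kind of access it waits on. A minimal sketch of the IR the updated CHECK lines expect (surrounding loads/stores elided, kind semantics as described in the PTX ISA):

// Previously matched as raw PTX text inside llvm.inline_asm:
//   tcgen05.wait::st.sync.aligned
//   tcgen05.wait::ld.sync.aligned

// Now emitted as NVVM dialect ops:
nvvm.tcgen05.wait <store>   // wait until prior tcgen05.st operations complete
nvvm.tcgen05.wait <load>    // wait until prior tcgen05.ld operations complete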
