Skip to content

Commit 0560390

Browse files
authored
[NVIDIA] Replace some NVGPU ops with equivalent NVVM ops (part 2) (#7471)
This change replaces some NVGPU ops with the corresponding NVVM ops. It aligns with previous discussions in PR #7420. For some op like NVGPU::FenceAsyncSharedOp, there is no corresponding Intrinsic, and LLVM will also generate PTX. However, in the long run, I think it is better to hand over the responsibility of generating code to LLVM instead of hard coding PTX at the NVGPU layer. The ConvertNVVMToLLVMPass has been added to the pipeline and build system so that NVVM ops are correctly lowered to LLVM IR.
1 parent 55cea61 commit 0560390

File tree

12 files changed

+47
-248
lines changed

12 files changed

+47
-248
lines changed

lib/Target/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_triton_library(TritonLLVMIR
1313
MLIRIndexToLLVM
1414
MLIRIR
1515
MLIRLLVMDialect
16+
MLIRNVVMToLLVM
1617
MLIRLLVMToLLVMIRTranslation
1718
MLIRNVVMToLLVMIRTranslation
1819
MLIRROCDLToLLVMIRTranslation

python/src/passes.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ void init_triton_passes_convert(py::module &&m) {
9999
ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass);
100100
ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass);
101101
ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass);
102+
ADD_PASS_WRAPPER_0("add_nvvm_to_llvm", createConvertNVVMToLLVMPass);
102103
}
103104

104105
void init_triton_passes_llvmir(py::module &&m) {

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
609609
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_hopper():
610610
# TODO: The use of stmatrix for Blackwell is currently not supported.
611611
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
612-
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
612+
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
613+
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]
613614

614615

615616
@triton.jit
@@ -1668,4 +1669,5 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
16681669
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and is_hopper():
16691670
# TODO: The use of stmatrix for Blackwell is currently not supported.
16701671
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
1671-
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
1672+
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
1673+
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,5 @@
11
// RUN: triton-opt %s --convert-nv-gpu-to-llvm -allow-unregistered-dialect -split-input-file | FileCheck %s
22

3-
// CHECK-LABEL: @nvvm_syncs
4-
llvm.func @nvvm_syncs() {
5-
// CHECK: fence.proxy.async.shared::cta;
6-
nvgpu.fence_async_shared {bCluster = false}
7-
// CHECK: fence.proxy.async.shared::cluster;
8-
nvgpu.fence_async_shared {bCluster = true}
9-
10-
llvm.return
11-
}
12-
133
// CHECK-LABEL: @cluster_id
144
llvm.func @cluster_id() -> i32 {
155
// CHECK: %cluster_ctaid.x;
@@ -23,30 +13,6 @@ llvm.func @cluster_id() -> i32 {
2313

2414
// -----
2515

26-
// CHECK-LABEL: @stmatrix
27-
llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) {
28-
// CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};
29-
nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32
30-
// CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4};
31-
nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32
32-
llvm.return
33-
}
34-
35-
// -----
36-
37-
// CHECK-LABEL: @ldmatrix
38-
llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
39-
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
40-
%0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
41-
// CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4];
42-
%1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
43-
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)>
44-
%3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)>
45-
llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)>
46-
}
47-
48-
// -----
49-
5016
!struct_128xf32 = !llvm.struct<(
5117
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
5218
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
880880
tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
881881
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
882882
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
883-
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
884-
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
885-
// CHECK-NOT: nvgpu.ldmatrix
883+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
884+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
885+
// CHECK-NOT: nvvm.ldmatrix
886886
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
887887
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
888888
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -910,9 +910,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
910910
tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
911911
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
912912
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
913-
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
914-
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
915-
// CHECK-NOT: nvgpu.ldmatrix
913+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
914+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
915+
// CHECK-NOT: nvvm.ldmatrix
916916
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
917917
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
918918
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -940,7 +940,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
940940
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
941941
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
942942
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
943-
// CHECK-NOT: nvgpu.ldmatrix
943+
// CHECK-NOT: nvvm.ldmatrix
944944
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
945945
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
946946
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -968,7 +968,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
968968
tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) {
969969
%AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
970970
%BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
971-
// CHECK-COUNT-32: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
971+
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
972+
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
972973
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a>
973974
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b>
974975
%cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
@@ -992,8 +993,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
992993
tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
993994
%AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
994995
%BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
995-
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
996-
// CHECK-NOT: nvgpu.ldmatrix
996+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
997+
// CHECK-NOT: nvvm.ldmatrix
997998
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
998999
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
9991000
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1308,7 +1309,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
13081309
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
13091310
%a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) {
13101311
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
1311-
// CHECK: nvgpu.ldmatrix
1312+
// CHECK: nvvm.ldmatrix
13121313
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
13131314
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
13141315

@@ -1384,9 +1385,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
13841385
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
13851386
%a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
13861387
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
1387-
// CHECK: nvgpu.ldmatrix
1388+
// CHECK: nvvm.ldmatrix
13881389
// CHECK-SAME: (i32, i32, i32, i32)
1389-
// CHECK: nvgpu.ldmatrix
1390+
// CHECK: nvvm.ldmatrix
13901391
// CHECK-SAME: (i32, i32, i32, i32)
13911392
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
13921393
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
@@ -1859,8 +1860,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
18591860
%f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
18601861
%i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem>
18611862

1862-
// CHECK: nvgpu.ldmatrix
1863-
// CHECK: nvgpu.ldmatrix
1863+
// CHECK: nvvm.ldmatrix
1864+
// CHECK: nvvm.ldmatrix
18641865

18651866
%f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
18661867
%i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b>

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
203203
// CHECK-LABEL: convert_mma_to_blocked
204204
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
205205
tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) {
206-
// CHECK-COUNT-16: nvgpu.stmatrix
206+
// CHECK-COUNT-16: nvvm.stmatrix
207207
// CHECK: nvvm.barrier0
208208
%c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
209209
tt.return
@@ -254,7 +254,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
254254
// CHECK-LABEL: distribute_to_shared_st_matrix
255255
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
256256
tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) {
257-
// CHECK-COUNT-16: nvgpu.stmatrix
257+
// CHECK-COUNT-16: nvvm.stmatrix
258258
// CHECK: llvm.return
259259
%b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
260260
tt.return
@@ -269,7 +269,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
269269
// CHECK-LABEL: distribute_to_shared_st_matrix_local_store
270270
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
271271
tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) {
272-
// CHECK-COUNT-16: nvgpu.stmatrix
272+
// CHECK-COUNT-16: nvvm.stmatrix
273273
// CHECK: llvm.return
274274
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
275275
ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
@@ -285,7 +285,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
285285
// CHECK-LABEL: distribute_to_shared_st_matrix_local_store
286286
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
287287
tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<64x128xf16, #linear>) {
288-
// CHECK-COUNT-8: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
288+
// CHECK-COUNT-8: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
289289
// CHECK: llvm.return
290290
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
291291
ttg.local_store %a, %b : tensor<64x128xf16, #linear> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
@@ -301,7 +301,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
301301
// CHECK-LABEL: distribute_to_swizzled_st_matrix_local_store
302302
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
303303
tt.func @distribute_to_swizzled_st_matrix_local_store(%a: tensor<8x64xf16, #mma>) {
304-
// CHECK-COUNT-2: nvgpu.stmatrix
304+
// CHECK-COUNT-2: nvvm.stmatrix
305305
// CHECK: llvm.return
306306
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
307307
ttg.local_store %a, %b : tensor<8x64xf16, #mma> -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
@@ -317,7 +317,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
317317
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
318318
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
319319
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) {
320-
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
320+
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
321321
// CHECK: llvm.return
322322
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
323323
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -339,7 +339,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
339339
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
340340
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
341341
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) {
342-
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
342+
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
343343
// CHECK: llvm.return
344344
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
345345
ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
@@ -355,7 +355,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
355355
// CHECK-LABEL: linear_to_swizzled_st_matrix_x2_local_store_fp8
356356
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
357357
tt.func @linear_to_swizzled_st_matrix_x2_local_store_fp8(%a: tensor<64x16xf8E4M3FNUZ, #linear>) {
358-
// CHECK-COUNT-1: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}} :
358+
// CHECK-COUNT-1: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>} :
359359
// CHECK: llvm.return
360360
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
361361
ttg.local_store %a, %b : tensor<64x16xf8E4M3FNUZ, #linear> -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
@@ -371,7 +371,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
371371
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store_fp32
372372
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
373373
tt.func @linear_to_swizzled_st_matrix_local_store_fp32(%a: tensor<64x16xf32, #linear>) {
374-
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
374+
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
375375
// CHECK: llvm.return
376376
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
377377
ttg.local_store %a, %b : tensor<64x16xf32, #linear> -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
@@ -388,7 +388,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
388388
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
389389
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
390390
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<64x32xf16, #linear>) {
391-
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
391+
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
392392
// CHECK: llvm.return
393393
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
394394
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -410,7 +410,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
410410
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
411411
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
412412
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<16x32xf16, #linear>) {
413-
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
413+
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
414414
// CHECK: llvm.return
415415
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
416416
ttg.local_store %a, %b : tensor<16x32xf16, #linear> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def make_llir(self, src, metadata, options, capability):
351351
passes.common.add_canonicalizer(pm)
352352
passes.common.add_cse(pm)
353353
passes.common.add_symbol_dce(pm)
354+
passes.convert.add_nvvm_to_llvm(pm)
354355
if not knobs.compilation.disable_line_info:
355356
passes.llvmir.add_di_scope(pm)
356357
pm.run(mod)

0 commit comments

Comments
 (0)