
Commit 5c02602

Merge OpenAI Triton commit 748a47e (#5217)
This PR changes the Triton base from 6e390f3 to 748a47e (Sep 24). Pass rate: 96.98% -> 96.99%.
2 parents 86d9987 + c1dcd7b commit 5c02602

File tree

13 files changed: +259, -312 lines


include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ class TargetInfoBase {
 
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
+  virtual bool supportLdStMatrixB8() const { return false; }
   virtual bool isCuda() const { return false; }
 
   // Annotate target specific information to local load operations during

include/triton/Tools/LinearLayout.h

Lines changed: 15 additions & 0 deletions
@@ -843,6 +843,10 @@ class ColumnAction {
   // Inverse of the action
   ColumnAction inverse() const;
 
+  // Given two permutations self, other seen as functions, returns
+  // ret(x) = other(self(x))
+  ColumnAction leftCompose(const ColumnAction &other) const;
+
   static ColumnAction identity(StringAttr inDim, size_t inSizeLog2) {
     return ColumnAction(llvm::to_vector(llvm::seq<size_t>(inSizeLog2)), inDim,
                         inSizeLog2);
@@ -854,6 +858,17 @@ class ColumnAction {
   std::string toString() const;
 };
 
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     const ColumnAction &action) {
+  os << action.toString();
+  return os;
+}
+
+inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) {
+  os << action.toString();
+  return os;
+}
+
 } // namespace mlir::triton
 
 #endif // TRITON_TOOLS_LINEARLAYOUT_H

lib/Tools/LinearLayout.cpp

Lines changed: 11 additions & 0 deletions
@@ -1344,6 +1344,17 @@ SmallVector<Value> ColumnAction::apply(ValueRange values) const {
   return ret;
 }
 
+ColumnAction ColumnAction::leftCompose(const ColumnAction &other) const {
+  assert(inDim == other.inDim);
+  assert(inSizeLog2 == other.inSizeLog2);
+  assert(action.size() == other.action.size());
+  auto newAction = SmallVector<size_t>(action.size());
+  for (size_t i = 0; i < action.size(); i++) {
+    newAction[i] = action[other.action[i]];
+  }
+  return ColumnAction(newAction, inDim, inSizeLog2);
+}
+
 ColumnAction ColumnAction::inverse() const {
   auto invPerm = SmallVector<size_t>(action.size());
   for (size_t i = 0; i < action.size(); i++) {
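
For intuition, here is a minimal Python sketch of the composition rule implemented above, modeling a column action as a plain list of indices. This is an illustration only, not the Triton ColumnAction API; the helper name and sample permutations are made up.

def left_compose(action, other):
    # Mirrors the C++ loop above: newAction[i] = action[other.action[i]].
    assert len(action) == len(other)
    return [action[other[i]] for i in range(len(other))]

identity = [0, 1, 2]
swap01 = [1, 0, 2]   # exchange the first two columns
rotate = [1, 2, 0]   # cyclic shift of the columns

assert left_compose(swap01, identity) == swap01
assert left_compose(swap01, rotate) == [0, 2, 1]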

python/test/gluon/test_lowerings.py

Lines changed: 13 additions & 2 deletions
@@ -701,6 +701,7 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl.
     ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16, rank=2),
     ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=True, element_bitwidth=16, rank=2),
     ttgl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False, element_bitwidth=16, rank=2),
+    ttgl.NVMMASharedLayout(swizzle_byte_width=32, transposed=False, element_bitwidth=8, rank=2),
     ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=1, order=[1, 0]),
     ttgl.SwizzledSharedLayout(vec=4, per_phase=2, max_phase=4, order=[0, 1]),
     ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0]),
@@ -767,13 +768,23 @@ def kernel(x_ptr, y_ptr, shape_tuple: ttgl.constexpr, src_layout: ttgl.constexpr
     else:
         x = torch.randn(shape, device=device, dtype=torch_dtype)
 
+    float8_dtypes = {torch.float8_e5m2}
+    if hasattr(torch, "float8_e4m3fn"):
+        float8_dtypes.add(torch.float8_e4m3fn)
+
+    def _assert_close(actual, expected):
+        if actual.dtype in float8_dtypes:
+            torch.testing.assert_close(actual.to(torch.float16), expected.to(torch.float16), rtol=0, atol=0)
+        else:
+            torch.testing.assert_close(actual, expected)
+
     y = torch.zeros_like(x)
     kernel[(1, )](x, y, shape, blocked_layout, dist_layout, shared_layout, num_warps=num_warps)
-    torch.testing.assert_close(y, x)
+    _assert_close(y, x)
 
     y = torch.zeros_like(x)
     obj = kernel[(1, )](x, y, shape, dist_layout, blocked_layout, shared_layout, num_warps=num_warps)
-    torch.testing.assert_close(y, x)
+    _assert_close(y, x)
     if (is_cuda() and isinstance(shared_layout, ttgl.NVMMASharedLayout) and dist_layout in _ld_st_mma_layouts
             and dist_layout.version[0] >= 3 and dtype == "float16"):
         assert "stmatrix" in obj.asm["ptx"]

python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 0 deletions
@@ -509,6 +509,10 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
     """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
+    if not isinstance(default_args, tuple):
+        default_args = (default_args, )
+    if not isinstance(worker_args, tuple):
+        worker_args = (worker_args, )
     return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)
 
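The change above lets callers pass a single value where a tuple of arguments is expected: non-tuple default_args and worker_args are wrapped into 1-tuples before being forwarded to _semantic.warp_specialize. A minimal standalone sketch of the same idiom, with a hypothetical helper name that is not part of the gluon API:

def _as_arg_tuple(args):
    # Wrap a single (non-tuple) value so downstream code can always iterate
    # over a tuple of arguments, matching the normalization in the diff above.
    return args if isinstance(args, tuple) else (args, )

assert _as_arg_tuple(42) == (42, )
assert _as_arg_tuple((1, 2)) == (1, 2)
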
test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 0 additions & 17 deletions
@@ -13,23 +13,6 @@ llvm.func @cluster_id() -> i32 {
 
 // -----
 
-// CHECK-LABEL: @ldmatrix
-llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
-  // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
-  %0 = nvgpu.ldmatrix %ptr, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  // CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4];
-  %1 = nvgpu.ldmatrix %ptr, m8n8, 16 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  // CHECK: ldmatrix.sync.aligned.m16n16.x4.trans.shared.b8 {$0, $1, $2, $3}, [$4];
-  %l = nvgpu.ldmatrix %ptr, m16n16, 8 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  %2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)>
-  %3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)>
-  %4 = llvm.extractvalue %l[0] : !llvm.struct<(i32, i32, i32, i32)>
-  %5 = llvm.insertvalue %4, %3[1] : !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return %5 : !llvm.struct<(i32, i32, i32, i32)>
-}
-
-// -----
-
 !struct_128xf32 = !llvm.struct<(
   f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
   f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 17 additions & 17 deletions
@@ -888,11 +888,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     // CHECK: llvm.mlir.addressof @global_smem
     // CHECK: llvm.store {{.*}} vector<4xi32>
     // CHECK: nvvm.bar.warp.sync
-    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvvm.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
     // CHECK: nvvm.bar.warp.sync
     // CHECK: llvm.store {{.*}} vector<4xi32>
     // CHECK: nvvm.bar.warp.sync
-    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvvm.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
     %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1>
     tt.return
   }
@@ -911,9 +911,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-    // CHECK-NOT: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<col>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK-NOT: nvvm.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -941,9 +941,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-    // CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-    // CHECK-NOT: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<col>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK-NOT: nvvm.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -971,7 +971,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK-NOT: nvgpu.ldmatrix
+    // CHECK-NOT: nvvm.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -999,7 +999,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
-    // CHECK-COUNT-32: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK-COUNT-32: nvvm.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
@@ -1023,8 +1023,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-    // CHECK-NOT: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 2 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK-NOT: nvvm.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1355,7 +1355,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
   %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix
     %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
 
@@ -1431,9 +1431,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
  %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix
     // CHECK-SAME: (i32, i32, i32, i32)
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix
     // CHECK-SAME: (i32, i32, i32, i32)
     %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
@@ -1936,8 +1936,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem>
 
-    // CHECK: nvgpu.ldmatrix
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvvm.ldmatrix
+    // CHECK: nvvm.ldmatrix
 
     %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b>

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 29 additions & 0 deletions
@@ -750,6 +750,35 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num
 
 // -----
 
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
+  // CHECK-LABEL: lower_ldmatrix_trans_b8
+  tt.func @lower_ldmatrix_trans_b8(%A: !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem, mutable, 1x128x64>) {
+    %0 = ttg.local_load %A : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem, mutable, 1x128x64> -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    // CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b8>, layout = #nvvm.mma_layout<col>{{.*}}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    tt.return
+  }
+}
+
+// -----
+
+#linear3 = #ttg.linear<{register = [[0, 0, 0, 1, 0], [0, 0, 0, 0, 8], [0, 0, 0, 8, 0], [0, 0, 0, 0, 16], [0, 0, 0, 0, 128]], lane = [[0, 0, 0, 2, 0], [0, 0, 0, 4, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 2], [0, 0, 0, 0, 4]], warp = [[0, 0, 0, 0, 32], [0, 0, 0, 0, 64]], block = []}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1, 1], CTAOrder = [4, 3, 2, 1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @stmatrix_b8_trans_linear
+  tt.func public @stmatrix_b8_trans_linear(%data: tensor<1x1x1x16x256xf8E4M3FN, #linear3>) {
+    // CHECK-COUNT-2: nvvm.stmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b8>, layout = #nvvm.mma_layout<col>{{.*}}} : !llvm.ptr<3>, i32, i32, i32, i32
+    %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1x1x1x16x256xf8E4M3FN, #shared, #smem, mutable>
+    ttg.local_store %data, %0 : tensor<1x1x1x16x256xf8E4M3FN, #linear3> -> !ttg.memdesc<1x1x1x16x256xf8E4M3FN, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 #bm64_bn128 = #ttng.tensor_memory_encoding<blockM = 64, blockN = 128, colStride = 1>
 #bm64_bn64 = #ttng.tensor_memory_encoding<blockM = 64, blockN = 64, colStride = 1>
 
test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 6 additions & 6 deletions
@@ -205,11 +205,11 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
   tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) {
     // CHECK-COUNT-8: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-8: nvgpu.ldmatrix
+    // CHECK-COUNT-8: nvvm.ldmatrix
     // CHECK: nvvm.barrier0
     // CHECK-COUNT-8: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-8: nvgpu.ldmatrix
+    // CHECK-COUNT-8: nvvm.ldmatrix
     %c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
     tt.return
   }
@@ -225,19 +225,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   tt.func @convert_blocked_to_dot_rhs(%a: tensor<64x64xf16, #blocked>) {
     // CHECK-COUNT-1: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-4: nvgpu.ldmatrix
+    // CHECK-COUNT-4: nvvm.ldmatrix
     // CHECK: nvvm.barrier0
     // CHECK-COUNT-1: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-4: nvgpu.ldmatrix
+    // CHECK-COUNT-4: nvvm.ldmatrix
     // CHECK: nvvm.barrier0
     // CHECK-COUNT-1: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-4: nvgpu.ldmatrix
+    // CHECK-COUNT-4: nvvm.ldmatrix
     // CHECK: nvvm.barrier0
     // CHECK-COUNT-1: llvm.store
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-4: nvgpu.ldmatrix
+    // CHECK-COUNT-4: nvvm.ldmatrix
     %b = ttg.convert_layout %a : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
     tt.return
   }

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 0 additions & 20 deletions
@@ -105,26 +105,6 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
   let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
 }
 
-def NVGPU_LoadMatrixShapeAttr : I32EnumAttr<
-  "LoadMatrixShape", "",
-  [
-    I32EnumAttrCase<"m8n8", 0, "m8n8">,
-    I32EnumAttrCase<"m16n16", 1, "m16n16">
-  ]> {
-  let cppNamespace = "::mlir::triton::nvgpu";
-}
-
-def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> {
-  let arguments = (
-    ins LLVM_PointerShared:$addr,
-    NVGPU_LoadMatrixShapeAttr:$shape,
-    I32Attr:$bit_width,
-    UnitAttr:$trans
-  );
-  let results = (outs AnyTypeOf<[LLVM_AnyStruct, I32]>:$result);
-  let assemblyFormat = "$addr `,` $shape `,` $bit_width attr-dict `:` functional-type($addr, $result)";
-}
-
 def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
   let results = (outs I32:$result);
   let assemblyFormat = "attr-dict";
