@@ -27,13 +27,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 
-  // CHECK-LABEL: wmma1_dot
-  tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
+  // CHECK-LABEL: wmma1_dot_f16
+  tt.func @wmma1_dot_f16(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
     // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
     // CHECK: llvm.mlir.undef : vector<16xf16>
     // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16>
-    // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    // CHECK: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1>
     // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
     // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
@@ -50,11 +50,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
     // CHECK: llvm.mlir.undef : vector<16xbf16>
     // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16>
-    // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+    // CHECK: wmma.bf16.16x16x16.bf16{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1>
     tt.return
   }
 
+  // CHECK-LABEL: wmma1_dot_f16_tied
+  tt.func @wmma1_dot_f16_tied(%arg0: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xf16, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK: llvm.mlir.undef : vector<16xf16>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xf16>
+    // CHECK-COUNT-2: wmma.f16.16x16x16.f16.tied{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xf16, #mma1>
+    // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
+    // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma1_dot_bf16_tied
+  tt.func @wmma1_dot_bf16_tied(%arg0: tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xbf16, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK: llvm.mlir.undef : vector<16xbf16>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xbf16>
+    // CHECK-COUNT-2: wmma.bf16.16x16x16.bf16.tied{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+    %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xbf16, #mma1>
+    // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xbf16>
+    // CHECK: llvm.mlir.undef : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    tt.return
+  }
+
   // CHECK-LABEL: wmma1_dot_int8_32
   tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
     // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
@@ -64,7 +92,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8>
     // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+    // CHECK: wmma.i32.16x16x16.iu8{{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
     %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
     tt.return
@@ -79,7 +107,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
     // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+    // CHECK: wmma.i32.16x16x16.iu4{{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
     %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
     tt.return
@@ -196,7 +224,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK-COUNT-32: llvm.insertelement
     // CHECK-COUNT-8: llvm.extractvalue %arg2
     // CHECK-COUNT-8: llvm.insertelement
-    // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    // CHECK-COUNT-2: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1>
     // CHECK-COUNT-8: llvm.extractelement
     // CHECK-COUNT-8: llvm.insertvalue