Skip to content

Commit 3e6c7f8

Browse files
committed
[NVPTX] Moved unsupported MLIR MMA tests to invalid.mlir. PR156040.
1 parent 2b34832 commit 3e6c7f8

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,34 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
743743

744744
// -----
745745

746+
// f32 return type, f16 accumulate type
747+
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
748+
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
749+
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
750+
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
751+
// C and D should have the same type according to PTX ISA
752+
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
753+
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
754+
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
755+
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
756+
}
757+
758+
// -----
759+
760+
// f16 return type, f32 accumulate type
761+
llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
762+
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
763+
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
764+
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
765+
// C and D should have the same type according to PTX ISA
766+
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
767+
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
768+
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
769+
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
770+
}
771+
772+
// -----
773+
746774
func.func @atomicrmw_mismatched_operands(%f32_ptr : !llvm.ptr, %f32 : f32) {
747775
// expected-error@+1 {{op failed to verify that result #0 and operand #1 have the same type}}
748776
%0 = "llvm.atomicrmw"(%f32_ptr, %f32) {bin_op=11, ordering=1} : (!llvm.ptr, f32) -> i32

0 commit comments

Comments
 (0)