From effe00fb825e080ba57117166d87126d84d16ad1 Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Fri, 24 Jan 2025 20:07:30 -0800 Subject: [PATCH 1/4] add 2:4 sparsity matrix ops to rocdl --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 95fbe7ed66a43..fe2036c246cc5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -408,6 +408,24 @@ def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">; def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">; def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>; def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>; + +// 2:4 Sparsity ops (GFX940) +def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">; +def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">; +def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">; +def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x23x16.bf16">; +def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.i8">; +def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.i8">; +def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">; +def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">; +def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; +def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; +def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">; +def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">; +def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fb8">; +def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">; + + //===---------------------------------------------------------------------===// // WMMA intrinsics class ROCDL_Wmma_IntrOp overloadedOperands, From bc26b45d2e5519ea6f940e2566b39cb0320be549 Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Fri, 24 Jan 2025 21:14:49 -0800 Subject: [PATCH 2/4] wip on lit tests --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 10 +++--- mlir/test/Target/LLVMIR/rocdl.mlir | 32 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index fe2036c246cc5..974712c581537 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -413,16 +413,16 @@ def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32. def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">; def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">; def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">; -def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x23x16.bf16">; -def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.i8">; -def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.i8">; +def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.bf16">; +def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">; +def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">; def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">; def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">; -def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; +def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">; def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">; def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">; -def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fb8">; +def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">; def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">; diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index b74edb6210683..882b7f205f848 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -395,6 +395,38 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32, (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %r34 = rocdl.smfmac.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + (vector<4xf16>, vector<8xf16>, vector<4xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %a, <8 x half> %b, <16 x float> %c, i32 %idx, i32 0, i32 0) + %r35 = rocdl.smfmac.f32.32x32x16.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + (vector<4xf16>, vector<8xf16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %a, <8 x i16> %b, <4 x float> %c, i32 %idx, i32 0, i32 0) + %r36 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + (vector<4xi16>, vector<8xi16>, vector<4xi16>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %a, <8 x i16> %b, <16 x float> %c, i32 %idx, i32 0, i32 0) + %r37 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + (vector<4xi16>, vector<8xi16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + +//def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">; +//def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">; +//def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">; +//def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">; +//def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">; +//def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; +//def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">; +//def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">; +//def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">; +//def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">; + llvm.return %r0 : vector<32 x f32> } From 4b2d023d9e3bce523d104ec9108ab3d54aff0b82 Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Sat, 25 Jan 2025 15:20:22 -0800 Subject: [PATCH 3/4] fix tests --- mlir/test/Target/LLVMIR/rocdl.mlir | 99 +++++++++++++++++++++++------- 1 file changed, 78 insertions(+), 21 deletions(-) diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 882b7f205f848..326bd3ae6b6f8 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -395,41 +395,98 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32, (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %r34 = rocdl.smfmac.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + llvm.return %r0 : vector<32 x f32> +} + +llvm.func @rocdl.smfmac(%arg0 : i32, + %arg1 : vector<4 x f16>, + %arg2 : vector<8 x f16>, + %arg3 : vector<4 x f32>, + %arg4 : vector<16 x f32>, + %arg5 : vector<4 x i16>, + %arg6 : vector<8 x i16>, + %arg7 : vector<2xi32>, + %arg8 : vector<4xi32>, + %arg9 : vector<16xi32>) -> vector<4 x f32> { + %csti32 = llvm.mlir.constant(42 : i32) : i32 + + // CHECK-LABEL: rocdl.smfmac + + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 : (vector<4xf16>, vector<8xf16>, vector<4xf32>, - i32, i32, i32) -> vector<16xf32> + i32, i32, i32) -> vector<4xf32> - // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %a, <8 x half> %b, <16 x float> %c, i32 %idx, i32 0, i32 0) - %r35 = rocdl.smfmac.f32.32x32x16.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %a, <8 x i16> %b, <4 x float> %c, i32 %idx, i32 0, i32 0) - %r36 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : - (vector<4xi16>, vector<8xi16>, vector<4xi16>, + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 : + (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %a, <8 x i16> %b, <16 x float> %c, i32 %idx, i32 0, i32 0) - %r37 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 : + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42) + %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xi32>, + i32, i32, i32) -> vector<4xi32> -//def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">; -//def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">; -//def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">; -//def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">; -//def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">; -//def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">; -//def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">; -//def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">; -//def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">; -//def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">; + // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42) + %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xi32>, + i32, i32, i32) -> vector<16xi32> - llvm.return %r0 : vector<32 x f32> + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + + // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42) + %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + llvm.return %r0 : vector<4 x f32> } + llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32, %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>, %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> { From 55327c3a95985253d531c05ec9254495d293c1c7 Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Sat, 25 Jan 2025 15:51:19 -0800 Subject: [PATCH 4/4] more mlir tests --- mlir/test/Dialect/LLVMIR/rocdl.mlir | 87 +++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 712f8c2a1caf6..5186e43398f01 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32, llvm.return } + +llvm.func @rocdl.smfmac(%arg0 : i32, + %arg1 : vector<4 x f16>, + %arg2 : vector<8 x f16>, + %arg3 : vector<4 x f32>, + %arg4 : vector<16 x f32>, + %arg5 : vector<4 x i16>, + %arg6 : vector<8 x i16>, + %arg7 : vector<2xi32>, + %arg8 : vector<4xi32>, + %arg9 : vector<16xi32>) -> vector<4 x f32> { + %csti32 = llvm.mlir.constant(42 : i32) : i32 + + // CHECK-LABEL: rocdl.smfmac + // CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 : + (vector<4xf16>, vector<8xf16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 : + (vector<4xf16>, vector<8xf16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 : + (vector<4xi16>, vector<8xi16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 : + (vector<4xi16>, vector<8xi16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xi32>, + i32, i32, i32) -> vector<4xi32> + + // CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xi32>, + i32, i32, i32) -> vector<16xi32> + + // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 : + (vector<2xi32>, vector<4xi32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + llvm.return %r0 : vector<4 x f32> +} + llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32, %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>, %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {