From aac972c877bcb8650f272f1952f07cdeb1c77422 Mon Sep 17 00:00:00 2001
From: Yi Qian
Date: Thu, 16 Jan 2025 20:59:00 +0000
Subject: [PATCH 1/3] Add gfx950 mfma instructions to ROCDL dialect

Add ROCDL support to the following instructions:
V_MFMA_F32_16X16X32_BF16
V_MFMA_I32_16X16X64_I8
V_MFMA_F32_16X16X32_F16
V_MFMA_F32_32X32X16_BF16
V_MFMA_I32_32X32X32_I8
V_MFMA_F32_32X32X16_F16
---
Review note: the rocdl.mfma.f32.16x16x32.f16 test now declares its result as
vector<4xf32>, matching its <4 x float> CHECK line and its f32 accumulator
operand, and the mfma.f32.32x32x16.bf16 CHECK line uses the generic %{{.*}}
operand pattern like the neighbouring tests.

 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td |  7 ++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 35 +++++++++++++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b..720c999f23025 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -388,6 +388,13 @@ def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.b
 def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
 def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
 def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
+// New in gfx950
+def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16">;
+def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8">;
+def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;
+def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16">;
+def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
+def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad..906f4d545d41b 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -219,7 +219,9 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
                    %arg4 : vector<16 x f32>, %arg5 : vector<4xf32>,
                    %arg6 : vector<4xf16>, %arg7 : vector<32 x i32>,
                    %arg8 : vector<16 x i32>, %arg9 : vector<4xi32>,
-                   %arg10 : vector<2xi16>, %arg11 : i64) -> vector<32 x f32> {
+                   %arg10 : vector<2xi16>, %arg11 : i64,
+                   %arg12 : vector<8xbf16>, %arg13 : vector<4xi32>,
+                   %arg14 : vector<8xf16>) -> vector<32 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.xdlops
@@ -362,6 +364,37 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   %r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
                                     (i64, i64, vector<16xf32>,
                                     i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12, %arg12, %arg5, %csti32, %csti32, %csti32 :
+                                    (vector<8xbf16>, vector<8xbf16>, vector<4xf32>,
+                                    i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x64.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r29 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, %csti32, %csti32, %csti32 :
+                                    (vector<4xi32>, vector<4xi32>, vector<4xi32>,
+                                    i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r30 = rocdl.mfma.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
+                                    (vector<8xf16>, vector<8xf16>, vector<4xf32>,
+                                    i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12, %arg12, %arg4, %csti32, %csti32, %csti32 :
+                                    (vector<8xbf16>, vector<8xbf16>, vector<16xf32>,
+                                    i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x32.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r32 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, %csti32, %csti32, %csti32 :
+                                    (vector<4xi32>, vector<4xi32>, vector<16xi32>,
+                                    i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r33 = rocdl.mfma.f32.32x32x16.f16 %arg14, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                    (vector<8xf16>, vector<8xf16>, vector<16xf32>,
+                                    i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<32 x f32>
 }
 

From 05785e6ecf6221075dacc4c82ee1cba256d80973 Mon Sep 17 00:00:00 2001
From: Yi Qian
Date: Fri, 17 Jan 2025 21:49:04 +0000
Subject: [PATCH 2/3] Remove an old comment

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 720c999f23025..37a44292af952 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -379,7 +379,6 @@ def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
 def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
 def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
 def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
-// fp8, only on gfx940
 def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
 def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
 def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;

From 342573d90e8c1832c70511609a1d484e8b0c5070 Mon Sep 17 00:00:00 2001
From: Yi Qian <68618497+yiqian1@users.noreply.github.com>
Date: Mon, 20 Jan 2025 09:56:37 -0600
Subject: [PATCH 3/3] Update a comment.

Co-authored-by: Jakub Kuderski
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 37a44292af952..2f976e41fc305 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -387,7 +387,7 @@ def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.b
 def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
 def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
 def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
-// New in gfx950
+// New in gfx950.
 def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16">;
 def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8">;
 def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;