diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 8759f1dc3269d..8b687a7f29bef 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -1360,6 +1360,37 @@ def ROCDL_CvtScaleF32PkFp4F32Op : }]; } +//===----------------------------------------------------------------------===// +// FMED3 operations +//===----------------------------------------------------------------------===// + +def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>, + Arguments<(ins LLVM_ScalarOrVectorOf:$src0, + LLVM_ScalarOrVectorOf:$src1, + LLVM_ScalarOrVectorOf:$src2)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let summary = "Median of three float/half values"; + let description = [{ + Computes the median of three floating-point values using the AMDGPU fmed3 intrinsic. + This operation is equivalent to `max(min(a, b), min(max(a, b), c))` but uses the + hardware-accelerated V_MED3_F16/V_MED3_F32 instruction for better performance. + + The operation supports both scalar and vector floating-point types (f16, f32). + + Example: + ```mlir + // Scalar f32 median + %result = rocdl.fmed3 %a, %b, %c : f32 + + // Vector f16 median + %result = rocdl.fmed3 %va, %vb, %vc : vector<4xf16> + ``` + }]; + let assemblyFormat = [{ + $src0 `,` $src1 `,` $src2 attr-dict `:` type($res) + }]; +} + //===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index e127fdb78a861..0bad151570029 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -29,6 +29,20 @@ func.func @rocdl_special_regs() -> i32 { llvm.return %0 : i32 } +func.func @rocdl.fmed3.scalar(%a: f32, %b: f32, %c: f32) -> f32 { + // CHECK-LABEL: rocdl.fmed3.scalar + // CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32 + %0 = rocdl.fmed3 %a, %b, %c : f32 + llvm.return %0 : f32 +} + +func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf16>) -> vector<4xf16> { + // CHECK-LABEL: rocdl.fmed3.vector + // CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : vector<4xf16> + %0 = rocdl.fmed3 %a, %b, %c : vector<4xf16> + llvm.return %0 : vector<4xf16> +} + func.func @rocdl.barrier() { // CHECK: rocdl.barrier rocdl.barrier diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index c629877b69b4e..e043a8c533d05 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1298,6 +1298,20 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 { llvm.return %ret : i32 } +llvm.func @test_fmed3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 { + // CHECK-LABEL: define half @test_fmed3_f16(half %0, half %1, half %2) + %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f16 + llvm.return %0 : f16 + // CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2) +} + +llvm.func @test_fmed3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 { + // CHECK-LABEL: define float @test_fmed3_f32(float %0, float %1, float %2) + %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32 + llvm.return %0 : f32 + // CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2) +} + // CHECK-LABEL: rocdl.cvt.scale.pk8 // CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 %[[SCALE:.+]]) llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {