Skip to content

Commit edc76e1

Browse files
[ROCDL][LLVM] Added rocdl.fmed3 -> Intrinsic::amdgcn_fmed3 (#159332)
## Description Added ROCDL fmed3 op to support rewrite to `amdgcn_fmed3` intrinsic. ## Testing - ROCDL -> LLVMIR lit tests for new `rocdl.med3` ops in `/test/Target/LLVMIR/rocdl.mlir` Addresses [#157052](#157052) --------- Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 631b89c commit edc76e1

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,37 @@ def ROCDL_CvtScaleF32PkFp4F32Op :
13601360
}];
13611361
}
13621362

1363+
//===----------------------------------------------------------------------===//
1364+
// FMED3 operations
1365+
//===----------------------------------------------------------------------===//
1366+
1367+
def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
1368+
Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src0,
1369+
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src1,
1370+
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src2)> {
1371+
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$res);
1372+
let summary = "Median of three float/half values";
1373+
let description = [{
1374+
Computes the median of three floating-point values using the AMDGPU fmed3 intrinsic.
1375+
This operation is equivalent to `max(min(a, b), min(max(a, b), c))` but uses the
1376+
hardware-accelerated V_MED3_F16/V_MED3_F32 instruction for better performance.
1377+
1378+
The operation supports both scalar and vector floating-point types (f16, f32).
1379+
1380+
Example:
1381+
```mlir
1382+
// Scalar f32 median
1383+
%result = rocdl.fmed3 %a, %b, %c : f32
1384+
1385+
// Vector f16 median
1386+
%result = rocdl.fmed3 %va, %vb, %vc : vector<4xf16>
1387+
```
1388+
}];
1389+
let assemblyFormat = [{
1390+
$src0 `,` $src1 `,` $src2 attr-dict `:` type($res)
1391+
}];
1392+
}
1393+
13631394
//===----------------------------------------------------------------------===//
13641395
// ROCDL target attribute.
13651396
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ func.func @rocdl_special_regs() -> i32 {
2929
llvm.return %0 : i32
3030
}
3131

32+
func.func @rocdl.fmed3.scalar(%a: f32, %b: f32, %c: f32) -> f32 {
33+
// CHECK-LABEL: rocdl.fmed3.scalar
34+
// CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32
35+
%0 = rocdl.fmed3 %a, %b, %c : f32
36+
llvm.return %0 : f32
37+
}
38+
39+
func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf16>) -> vector<4xf16> {
40+
// CHECK-LABEL: rocdl.fmed3.vector
41+
// CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : vector<4xf16>
42+
%0 = rocdl.fmed3 %a, %b, %c : vector<4xf16>
43+
llvm.return %0 : vector<4xf16>
44+
}
45+
3246
func.func @rocdl.barrier() {
3347
// CHECK: rocdl.barrier
3448
rocdl.barrier

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,20 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
12981298
llvm.return %ret : i32
12991299
}
13001300

1301+
llvm.func @test_fmed3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
1302+
// CHECK-LABEL: define half @test_fmed3_f16(half %0, half %1, half %2)
1303+
%0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f16
1304+
llvm.return %0 : f16
1305+
// CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2)
1306+
}
1307+
1308+
llvm.func @test_fmed3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
1309+
// CHECK-LABEL: define float @test_fmed3_f32(float %0, float %1, float %2)
1310+
%0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32
1311+
llvm.return %0 : f32
1312+
// CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2)
1313+
}
1314+
13011315
// CHECK-LABEL: rocdl.cvt.scale.pk8
13021316
// CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 %[[SCALE:.+]])
13031317
llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {

0 commit comments

Comments
 (0)