Skip to content

Commit 829dc0f

Browse files
Variadic rocdl.fmed3 op; print tests; addresed comments
Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 776fc40 commit 829dc0f

File tree

3 files changed

+39
-14
lines changed

3 files changed

+39
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,22 +1292,33 @@ def ROCDL_CvtScaleF32PkFp4F32Op :
12921292
}
12931293

12941294
//===----------------------------------------------------------------------===//
1295-
// MED3 operations
1295+
// FMED3 operations
12961296
//===----------------------------------------------------------------------===//
12971297

1298-
def ROCDL_Med3Op : ROCDL_ConcreteNonMemIntrOp<"med3", [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
1298+
def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
12991299
Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src0,
13001300
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src1,
13011301
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src2)> {
13021302
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$res);
13031303
let summary = "Median of three float/half values";
1304-
let assemblyFormat = [{
1305-
$src0 `,` $src1 `,` $src2 attr-dict `:` `(` type($src0) `,` type($src1) `,` type($src2) `)` `->` type($res)
1304+
let description = [{
1305+
Computes the median of three floating-point values using the AMDGPU fmed3 intrinsic.
1306+
This operation is equivalent to `max(min(a, b), min(max(a, b), c))` but uses the
1307+
hardware-accelerated V_MED3_F16/V_MED3_F32 instruction for better performance.
1308+
1309+
The operation supports both scalar and vector floating-point types (f16, f32).
1310+
1311+
Example:
1312+
```mlir
1313+
// Scalar f32 median
1314+
%result = rocdl.fmed3 %a, %b, %c : f32
1315+
1316+
// Vector f16 median
1317+
%result = rocdl.fmed3 %va, %vb, %vc : vector<4xf16>
1318+
```
13061319
}];
1307-
string llvmBuilder = [{
1308-
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_fmed3,
1309-
{$src0, $src1, $src2},
1310-
{moduleTranslation.convertType(op.getRes().getType())});
1320+
let assemblyFormat = [{
1321+
$src0 `,` $src1 `,` $src2 attr-dict `:` type($res)
13111322
}];
13121323
}
13131324

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ func.func @rocdl_special_regs() -> i32 {
2929
llvm.return %0 : i32
3030
}
3131

32+
func.func @rocdl.fmed3.scalar(%a: f32, %b: f32, %c: f32) -> f32 {
33+
// CHECK-LABEL: rocdl.fmed3.scalar
34+
// CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32
35+
%0 = rocdl.fmed3 %a, %b, %c : f32
36+
llvm.return %0 : f32
37+
}
38+
39+
func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf16>) -> vector<4xf16> {
40+
// CHECK-LABEL: rocdl.fmed3.vector
41+
// CHECK: %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : vector<4xf16>
42+
%0 = rocdl.fmed3 %a, %b, %c : vector<4xf16>
43+
llvm.return %0 : vector<4xf16>
44+
}
45+
3246
func.func @rocdl.barrier() {
3347
// CHECK: rocdl.barrier
3448
rocdl.barrier

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,16 +1298,16 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
12981298
llvm.return %ret : i32
12991299
}
13001300

1301-
llvm.func @test_med3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
1302-
// CHECK-LABEL: define half @test_med3_f16(half %0, half %1, half %2)
1303-
%0 = rocdl.med3 %arg0, %arg1, %arg2 : (f16, f16, f16) -> f16
1301+
llvm.func @test_fmed3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
1302+
// CHECK-LABEL: define half @test_fmed3_f16(half %0, half %1, half %2)
1303+
%0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f16
13041304
llvm.return %0 : f16
13051305
// CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2)
13061306
}
13071307

1308-
llvm.func @test_med3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
1309-
// CHECK-LABEL: define float @test_med3_f32(float %0, float %1, float %2)
1310-
%0 = rocdl.med3 %arg0, %arg1, %arg2 : (f32, f32, f32) -> f32
1308+
llvm.func @test_fmed3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
1309+
// CHECK-LABEL: define float @test_fmed3_f32(float %0, float %1, float %2)
1310+
%0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32
13111311
llvm.return %0 : f32
13121312
// CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2)
13131313
}

0 commit comments

Comments
 (0)