From 783737e69aaf9be791d43b85e808fbdd87baec71 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Thu, 17 Apr 2025 15:50:49 -0700 Subject: [PATCH 1/2] [MLIR][LLVMIR] Extend llrint,lrint,lround for vectors of float Matching langref. Note that `llround` is different than the rest. --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 12 +++-- mlir/test/Target/LLVMIR/Import/intrinsic.ll | 46 +++++++++++++------ .../test/Target/LLVMIR/llvmir-intrinsics.mlir | 29 ++++++++++-- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index ab928c9e2d0e7..bc85881c94cd9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -153,16 +153,18 @@ def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">; def LLVM_PowIOp : LLVM_PowFI<"powi">; def LLVM_RintOp : LLVM_UnaryIntrOpF<"rint">; def LLVM_NearbyintOp : LLVM_UnaryIntrOpF<"nearbyint">; -class LLVM_IntRoundIntrOpBase : +class LLVM_IntRoundIntrOpBase : LLVM_OneResultIntrOp { - let arguments = (ins LLVM_AnyFloat:$val); + let arguments = (ins element:$val); let assemblyFormat = "`(` operands `)` attr-dict `:` " "functional-type(operands, results)"; } -def LLVM_LroundOp : LLVM_IntRoundIntrOpBase<"lround">; +class LLVM_IntRoundIntrVecOrFloatOpBase : + LLVM_IntRoundIntrOpBase>; +def LLVM_LroundOp : LLVM_IntRoundIntrVecOrFloatOpBase<"lround">; def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">; -def LLVM_LrintOp : LLVM_IntRoundIntrOpBase<"lrint">; -def LLVM_LlrintOp : LLVM_IntRoundIntrOpBase<"llrint">; +def LLVM_LrintOp : LLVM_IntRoundIntrVecOrFloatOpBase<"lrint">; +def LLVM_LlrintOp : LLVM_IntRoundIntrVecOrFloatOpBase<"llrint">; def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">; def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">; def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">; diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index b0a36939d8c48..36afa84a031f4 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -235,15 +235,23 @@ define void @nearbyint_test(float %0, double %1, <8 x float> %2, <8 x double> %3 ret void } ; CHECK-LABEL: llvm.func @lround_test -define void @lround_test(float %0, double %1) { +define void @lround_test(float %0, double %1, <2 x float> %2, <2 x double> %3) { ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32 - %3 = call i32 @llvm.lround.i32.f32(float %0) + %5 = call i32 @llvm.lround.i32.f32(float %0) ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i64 - %4 = call i64 @llvm.lround.i64.f32(float %0) + %6 = call i64 @llvm.lround.i64.f32(float %0) ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i32 - %5 = call i32 @llvm.lround.i32.f64(double %1) + %7 = call i32 @llvm.lround.i32.f64(double %1) ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i64 - %6 = call i64 @llvm.lround.i64.f64(double %1) + %8 = call i64 @llvm.lround.i64.f64(double %1) + ; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf32>) -> vector<2xi32> + %9 = call <2 x i32> @llvm.lround.v2i32.v2f32(<2 x float> %2) + ; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf64>) -> vector<2xi32> + %10 = call <2 x i32> @llvm.lround.v2i32.v2f64(<2 x double> %3) + ; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf32>) -> vector<2xi64> + %11 = call <2 x i64> @llvm.lround.v2i64.v2f32(<2 x float> %2) + ; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf64>) -> vector<2xi64> + %12 = call <2 x i64> @llvm.lround.v2i64.v2f64(<2 x double> %3) ret void } ; CHECK-LABEL: llvm.func @llround_test @@ -255,23 +263,35 @@ define void @llround_test(float %0, double %1) { ret void } ; CHECK-LABEL: llvm.func @lrint_test -define void @lrint_test(float %0, double %1) { +define void @lrint_test(float %0, double %1, <2 x float> %2, <2 x double> %3) { ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i32 - %3 = call i32 @llvm.lrint.i32.f32(float %0) + %5 = call i32 @llvm.lrint.i32.f32(float %0) ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i64 - %4 = call i64 @llvm.lrint.i64.f32(float %0) + %6 = call i64 @llvm.lrint.i64.f32(float %0) ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i32 - %5 = call i32 @llvm.lrint.i32.f64(double %1) + %7 = call i32 @llvm.lrint.i32.f64(double %1) ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i64 - %6 = call i64 @llvm.lrint.i64.f64(double %1) + %8 = call i64 @llvm.lrint.i64.f64(double %1) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi32> + %9 = call <2 x i32> @llvm.lrint.v2i32.v2f32(<2 x float> %2) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi32> + %10 = call <2 x i32> @llvm.lrint.v2i32.v2f64(<2 x double> %3) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi64> + %11 = call <2 x i64> @llvm.lrint.v2i64.v2f32(<2 x float> %2) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi64> + %12 = call <2 x i64> @llvm.lrint.v2i64.v2f64(<2 x double> %3) ret void } ; CHECK-LABEL: llvm.func @llrint_test -define void @llrint_test(float %0, double %1) { +define void @llrint_test(float %0, double %1, <2 x float> %2, <2 x double> %3) { ; CHECK: llvm.intr.llrint(%{{.*}}) : (f32) -> i64 - %3 = call i64 @llvm.llrint.i64.f32(float %0) + %5 = call i64 @llvm.llrint.i64.f32(float %0) ; CHECK: llvm.intr.llrint(%{{.*}}) : (f64) -> i64 - %4 = call i64 @llvm.llrint.i64.f64(double %1) + %6 = call i64 @llvm.llrint.i64.f64(double %1) + ; CHECK: llvm.intr.llrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi64> + %7 = call <2 x i64> @llvm.llrint.v2i64.v2f32(<2 x float> %2) + ; CHECK: llvm.intr.llrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi64> + %8 = call <2 x i64> @llvm.llrint.v2i64.v2f64(<2 x double> %3) ret void } diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 8088dec811e13..ba12140c59b35 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -242,7 +242,8 @@ llvm.func @nearbyint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 } // CHECK-LABEL: @lround_test -llvm.func @lround_test(%arg0 : f32, %arg1 : f64) { +llvm.func @lround_test(%arg0 : f32, %arg1 : f64, + %arg2 : vector<2xf32>, %arg3 : vector<2xf64>) { // CHECK: call i32 @llvm.lround.i32.f32 "llvm.intr.lround"(%arg0) : (f32) -> i32 // CHECK: call i64 @llvm.lround.i64.f32 @@ -251,6 +252,14 @@ llvm.func @lround_test(%arg0 : f32, %arg1 : f64) { "llvm.intr.lround"(%arg1) : (f64) -> i32 // CHECK: call i64 @llvm.lround.i64.f64 "llvm.intr.lround"(%arg1) : (f64) -> i64 + // CHECK: call <2 x i32> @llvm.lround.v2i32.v2f32 + "llvm.intr.lround"(%arg2) : (vector<2xf32>) -> vector<2xi32> + // CHECK: call <2 x i32> @llvm.lround.v2i32.v2f64 + "llvm.intr.lround"(%arg3) : (vector<2xf64>) -> vector<2xi32> + // CHECK: call <2 x i64> @llvm.lround.v2i64.v2f32 + "llvm.intr.lround"(%arg2) : (vector<2xf32>) -> vector<2xi64> + // CHECK: call <2 x i64> @llvm.lround.v2i64.v2f64 + "llvm.intr.lround"(%arg3) : (vector<2xf64>) -> vector<2xi64> llvm.return } @@ -264,7 +273,8 @@ llvm.func @llround_test(%arg0 : f32, %arg1 : f64) { } // CHECK-LABEL: @lrint_test -llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) { +llvm.func @lrint_test(%arg0 : f32, %arg1 : f64, + %arg2 : vector<2xf32>, %arg3 : vector<2xf64>) { // CHECK: call i32 @llvm.lrint.i32.f32 "llvm.intr.lrint"(%arg0) : (f32) -> i32 // CHECK: call i64 @llvm.lrint.i64.f32 @@ -273,15 +283,28 @@ llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) { "llvm.intr.lrint"(%arg1) : (f64) -> i32 // CHECK: call i64 @llvm.lrint.i64.f64 "llvm.intr.lrint"(%arg1) : (f64) -> i64 + // CHECK: call <2 x i32> @llvm.lrint.v2i32.v2f32 + "llvm.intr.lrint"(%arg2) : (vector<2xf32>) -> vector<2xi32> + // CHECK: call <2 x i32> @llvm.lrint.v2i32.v2f64 + "llvm.intr.lrint"(%arg3) : (vector<2xf64>) -> vector<2xi32> + // CHECK: call <2 x i64> @llvm.lrint.v2i64.v2f32 + "llvm.intr.lrint"(%arg2) : (vector<2xf32>) -> vector<2xi64> + // CHECK: call <2 x i64> @llvm.lrint.v2i64.v2f64 + "llvm.intr.lrint"(%arg3) : (vector<2xf64>) -> vector<2xi64> llvm.return } // CHECK-LABEL: @llrint_test -llvm.func @llrint_test(%arg0 : f32, %arg1 : f64) { +llvm.func @llrint_test(%arg0 : f32, %arg1 : f64, + %arg2 : vector<2xf32>, %arg3 : vector<2xf64>) { // CHECK: call i64 @llvm.llrint.i64.f32 "llvm.intr.llrint"(%arg0) : (f32) -> i64 // CHECK: call i64 @llvm.llrint.i64.f64 "llvm.intr.llrint"(%arg1) : (f64) -> i64 + // CHECK: call <2 x i64> @llvm.llrint.v2i64.v2f32 + "llvm.intr.llrint"(%arg2) : (vector<2xf32>) -> vector<2xi64> + // CHECK: call <2 x i64> @llvm.llrint.v2i64.v2f64 + "llvm.intr.llrint"(%arg3) : (vector<2xf64>) -> vector<2xi64> llvm.return } From 20db049306cac34c6f4b57e4566c914d2f54e0b9 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Fri, 18 Apr 2025 10:26:00 -0700 Subject: [PATCH 2/2] Address review --- mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index bc85881c94cd9..cd8b68e5b1410 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -159,12 +159,12 @@ class LLVM_IntRoundIntrOpBase : let assemblyFormat = "`(` operands `)` attr-dict `:` " "functional-type(operands, results)"; } -class LLVM_IntRoundIntrVecOrFloatOpBase : +class LLVM_ScalarOrVectorIntRoundIntrOpBase : LLVM_IntRoundIntrOpBase>; -def LLVM_LroundOp : LLVM_IntRoundIntrVecOrFloatOpBase<"lround">; +def LLVM_LroundOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"lround">; def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">; -def LLVM_LrintOp : LLVM_IntRoundIntrVecOrFloatOpBase<"lrint">; -def LLVM_LlrintOp : LLVM_IntRoundIntrVecOrFloatOpBase<"llrint">; +def LLVM_LrintOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"lrint">; +def LLVM_LlrintOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"llrint">; def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">; def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">; def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">;