Skip to content

Commit ce07b95

Browse files
committed
[mlir][math] Support vector type by erf and round libm lowering
erf and round op are able to lowered to libm supporting vector type as other math operations. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D127934
1 parent 0baf13e commit ce07b95

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,12 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
142142
PatternBenefit benefit) {
143143
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
144144
VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
145-
VecOpToScalarOp<math::SinOp>>(patterns.getContext(), benefit);
145+
VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
146+
VecOpToScalarOp<math::RoundOp>>(patterns.getContext(), benefit);
146147
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
147148
PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
148-
PromoteOpToF32<math::SinOp>>(patterns.getContext(), benefit);
149+
PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
150+
PromoteOpToF32<math::RoundOp>>(patterns.getContext(), benefit);
149151
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
150152
"atan2f", "atan2", benefit);
151153
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",

mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,30 @@ func.func @erf_caller(%float: f32, %double: f64) -> (f32, f64) {
6363
return %float_result, %double_result : f32, f64
6464
}
6565

66+
// CHECK-LABEL: func @erf_vec_caller(
67+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
68+
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
69+
func.func @erf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
70+
// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
71+
// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
72+
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
73+
// CHECK: %[[OUT0_F32:.*]] = call @erff(%[[IN0_F32]]) : (f32) -> f32
74+
// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
75+
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
76+
// CHECK: %[[OUT1_F32:.*]] = call @erff(%[[IN1_F32]]) : (f32) -> f32
77+
// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
78+
%float_result = math.erf %float : vector<2xf32>
79+
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
80+
// CHECK: %[[OUT0_F64:.*]] = call @erf(%[[IN0_F64]]) : (f64) -> f64
81+
// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
82+
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
83+
// CHECK: %[[OUT1_F64:.*]] = call @erf(%[[IN1_F64]]) : (f64) -> f64
84+
// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
85+
%double_result = math.erf %double : vector<2xf64>
86+
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
87+
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
88+
}
89+
6690
// CHECK-LABEL: func @expm1_caller
6791
// CHECK-SAME: %[[FLOAT:.*]]: f32
6892
// CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -157,3 +181,27 @@ func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64) {
157181
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
158182
return %float_result, %double_result : f32, f64
159183
}
184+
185+
// CHECK-LABEL: func @round_vec_caller(
186+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
187+
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
188+
func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
189+
// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
190+
// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
191+
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
192+
// CHECK: %[[OUT0_F32:.*]] = call @roundf(%[[IN0_F32]]) : (f32) -> f32
193+
// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
194+
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
195+
// CHECK: %[[OUT1_F32:.*]] = call @roundf(%[[IN1_F32]]) : (f32) -> f32
196+
// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
197+
%float_result = math.round %float : vector<2xf32>
198+
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
199+
// CHECK: %[[OUT0_F64:.*]] = call @round(%[[IN0_F64]]) : (f64) -> f64
200+
// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
201+
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
202+
// CHECK: %[[OUT1_F64:.*]] = call @round(%[[IN1_F64]]) : (f64) -> f64
203+
// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
204+
%double_result = math.round %double : vector<2xf64>
205+
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
206+
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
207+
}

0 commit comments

Comments
 (0)