Skip to content

Commit 9da0ef1

Browse files
authored
Fix complex tanh overflows. (#88708)
This ports the XLA lowering and was verified using XLA's exhaustive_unary_test_complex test.
1 parent 14774ad commit 9da0ef1

File tree

2 files changed

+157
-35
lines changed

2 files changed

+157
-35
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -978,30 +978,84 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
978978
LogicalResult
979979
matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
980980
ConversionPatternRewriter &rewriter) const override {
981+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
981982
auto loc = op.getLoc();
982983
auto type = cast<ComplexType>(adaptor.getComplex().getType());
983984
auto elementType = cast<FloatType>(type.getElementType());
984-
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
985+
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
986+
const auto &floatSemantics = elementType.getFloatSemantics();
985987

986-
// The hyperbolic tangent for complex number can be calculated as follows.
987-
// tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
988-
// See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
989988
Value real =
990-
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
989+
b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
991990
Value imag =
992-
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
993-
Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
994-
Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
995-
Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
996-
Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
997-
Value numerator =
998-
rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
999-
Value one = rewriter.create<arith::ConstantOp>(
1000-
loc, elementType, rewriter.getFloatAttr(elementType, 1));
1001-
Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
1002-
Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
1003-
rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
1004-
fmf);
991+
b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
992+
993+
auto cst = [&](APFloat v) {
994+
return b.create<arith::ConstantOp>(elementType,
995+
b.getFloatAttr(elementType, v));
996+
};
997+
Value inf = cst(APFloat::getInf(floatSemantics));
998+
Value negOne = b.create<arith::ConstantOp>(
999+
elementType, b.getFloatAttr(elementType, -1.0));
1000+
Value four = b.create<arith::ConstantOp>(elementType,
1001+
b.getFloatAttr(elementType, 4.0));
1002+
Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1003+
Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1004+
1005+
Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1006+
Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1007+
Value realNum =
1008+
b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1009+
1010+
Value cosImag = b.create<math::CosOp>(imag, fmf);
1011+
Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1012+
Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1013+
Value sinImag = b.create<math::SinOp>(imag, fmf);
1014+
1015+
Value imagNum = b.create<arith::MulFOp>(
1016+
four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1017+
1018+
Value expSumMinusTwo =
1019+
b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1020+
Value denom =
1021+
b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1022+
1023+
Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1024+
expSumMinusTwo, inf, fmf);
1025+
Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1026+
1027+
Value resultReal = b.create<arith::SelectOp>(
1028+
isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1029+
Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1030+
1031+
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1032+
arith::FastMathFlags::ninf)) {
1033+
Value absReal = b.create<math::AbsFOp>(real, fmf);
1034+
Value zero = b.create<arith::ConstantOp>(
1035+
elementType, b.getFloatAttr(elementType, 0.0));
1036+
Value nan = cst(APFloat::getNaN(floatSemantics));
1037+
1038+
Value absRealIsInf =
1039+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1040+
Value imagIsZero =
1041+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1042+
Value absRealIsNotInf = b.create<arith::XOrIOp>(
1043+
absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1044+
1045+
Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1046+
imagNum, imagNum, fmf);
1047+
Value resultRealIsNaN =
1048+
b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1049+
Value resultImagIsZero = b.create<arith::OrIOp>(
1050+
imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1051+
1052+
resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1053+
resultImag =
1054+
b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1055+
}
1056+
1057+
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1058+
resultImag);
10051059
return success();
10061060
}
10071061
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -679,14 +679,42 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
679679
}
680680
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
681681
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
682-
// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] : f32
683-
// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] : f32
684-
// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] : f32
685-
// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] : f32
686-
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
687-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
688-
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
689-
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
682+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
683+
// CHECK: %[[NEG_ONE:.*]] = arith.constant -1.000000e+00 : f32
684+
// CHECK: %[[FOUR:.*]] = arith.constant 4.000000e+00 : f32
685+
// CHECK: %[[TWO_REAL:.*]] = arith.addf %[[REAL]], %[[REAL]] : f32
686+
// CHECK: %[[NEG_TWO_REAL:.*]] = arith.mulf %[[NEG_ONE]], %[[TWO_REAL]] : f32
687+
// CHECK: %[[EXPM1:.*]] = math.expm1 %[[TWO_REAL]] : f32
688+
// CHECK: %[[EXPM1_2:.*]] = math.expm1 %[[NEG_TWO_REAL]] : f32
689+
// CHECK: %[[REAL_NUM:.*]] = arith.subf %[[EXPM1]], %[[EXPM1_2]] : f32
690+
// CHECK: %[[COS:.*]] = math.cos %[[IMAG]] : f32
691+
// CHECK: %[[COS_SQ:.*]] = arith.mulf %[[COS]], %[[COS]] : f32
692+
// CHECK: %[[FOUR_COS_SQ:.*]] = arith.mulf %[[COS_SQ]], %[[FOUR]] : f32
693+
// CHECK: %[[SIN:.*]] = math.sin %[[IMAG]] : f32
694+
// CHECK: %[[MUL:.*]] = arith.mulf %[[COS]], %[[SIN]] : f32
695+
// CHECK: %[[IMAG_NUM:.*]] = arith.mulf %[[FOUR]], %[[MUL]] : f32
696+
// CHECK: %[[ADD:.*]] = arith.addf %[[EXPM1]], %[[EXPM1_2]] : f32
697+
// CHECK: %[[DENOM:.*]] = arith.addf %[[ADD]], %[[FOUR_COS_SQ]] : f32
698+
// CHECK: %[[IS_INF:.*]] = arith.cmpf oeq, %[[ADD]], %[[INF]] : f32
699+
// CHECK: %[[LIMIT:.*]] = math.copysign %[[NEG_ONE]], %[[REAL]] : f32
700+
// CHECK: %[[RESULT_REAL:.*]] = arith.divf %[[REAL_NUM]], %[[DENOM]] : f32
701+
// CHECK: %[[RESULT_REAL2:.*]] = arith.select %[[IS_INF]], %[[LIMIT]], %[[RESULT_REAL]] : f32
702+
// CHECK: %[[RESULT_IMAG:.*]] = arith.divf %[[IMAG_NUM]], %[[DENOM]] : f32
703+
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
704+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
705+
// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
706+
// CHECK: %[[ABS_REAL_INF:.*]] = arith.cmpf oeq, %[[ABS_REAL]], %[[INF]] : f32
707+
// CHECK: %[[IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
708+
// CHECK: %true = arith.constant true
709+
// CHECK: %[[ABS_REAL_NOT_INF:.*]] = arith.xori %[[ABS_REAL_INF]], %true : i1
710+
// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG_NUM]], %[[IMAG_NUM]] : f32
711+
// CHECK: %[[REAL_IS_NAN:.*]] = arith.andi %[[IMAG_IS_NAN]], %[[ABS_REAL_NOT_INF]] : i1
712+
// CHECK: %[[AND:.*]] = arith.andi %[[ABS_REAL_INF]], %[[IMAG_IS_NAN]] : i1
713+
// CHECK: %[[IMAG_IS_NAN2:.*]] = arith.ori %[[IMAG_ZERO]], %[[AND]] : i1
714+
// CHECK: %[[RESULT_REAL3:.*]] = arith.select %[[REAL_IS_NAN]], %[[NAN]], %[[RESULT_REAL2]] : f32
715+
// CHECK: %[[RESULT_IMAG2:.*]] = arith.select %[[IMAG_IS_NAN2]], %[[ZERO]], %[[RESULT_IMAG]] : f32
716+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL3]], %[[RESULT_IMAG2]] : complex<f32>
717+
// CHECK: return %[[RESULT]] : complex<f32>
690718

691719
// -----
692720

@@ -2100,7 +2128,6 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
21002128
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
21012129
// CHECK: return %[[RESULT]] : complex<f32>
21022130

2103-
21042131
// -----
21052132

21062133
// CHECK-LABEL: func @complex_tanh_with_fmf
@@ -2109,13 +2136,54 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
21092136
%tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
21102137
return %tanh : complex<f32>
21112138
}
2139+
21122140
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
21132141
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
2114-
// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
2115-
// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
2116-
// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
2117-
// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
2118-
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
2119-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
2120-
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
2121-
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
2142+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
2143+
// CHECK: %[[NEG_ONE:.*]] = arith.constant -1.000000e+00 : f32
2144+
// CHECK: %[[FOUR:.*]] = arith.constant 4.000000e+00 : f32
2145+
// CHECK: %[[TWO_REAL:.*]] = arith.addf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
2146+
// CHECK: %[[NEG_TWO_REAL:.*]] = arith.mulf %[[NEG_ONE]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
2147+
// CHECK: %[[EXPM1:.*]] = math.expm1 %[[TWO_REAL]] fastmath<nnan,contract> : f32
2148+
// CHECK: %[[EXPM1_2:.*]] = math.expm1 %[[NEG_TWO_REAL]] fastmath<nnan,contract> : f32
2149+
// CHECK: %[[REAL_NUM:.*]] = arith.subf %[[EXPM1]], %[[EXPM1_2]] fastmath<nnan,contract> : f32
2150+
// CHECK: %[[COS:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
2151+
// CHECK: %[[COS_SQ:.*]] = arith.mulf %[[COS]], %[[COS]] fastmath<nnan,contract> : f32
2152+
// CHECK: %[[FOUR_COS_SQ:.*]] = arith.mulf %[[COS_SQ]], %[[FOUR]] fastmath<nnan,contract> : f32
2153+
// CHECK: %[[SIN:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
2154+
// CHECK: %[[MUL:.*]] = arith.mulf %[[COS]], %[[SIN]] fastmath<nnan,contract> : f32
2155+
// CHECK: %[[IMAG_NUM:.*]] = arith.mulf %[[FOUR]], %[[MUL]] fastmath<nnan,contract> : f32
2156+
// CHECK: %[[ADD:.*]] = arith.addf %[[EXPM1]], %[[EXPM1_2]] fastmath<nnan,contract> : f32
2157+
// CHECK: %[[DENOM:.*]] = arith.addf %[[ADD]], %[[FOUR_COS_SQ]] fastmath<nnan,contract> : f32
2158+
// CHECK: %[[IS_INF:.*]] = arith.cmpf oeq, %[[ADD]], %[[INF]] fastmath<nnan,contract> : f32
2159+
// CHECK: %[[LIMIT:.*]] = math.copysign %[[NEG_ONE]], %[[REAL]] fastmath<nnan,contract> : f32
2160+
// CHECK: %[[RESULT_REAL:.*]] = arith.divf %[[REAL_NUM]], %[[DENOM]] fastmath<nnan,contract> : f32
2161+
// CHECK: %[[RESULT_REAL2:.*]] = arith.select %[[IS_INF]], %[[LIMIT]], %[[RESULT_REAL]] : f32
2162+
// CHECK: %[[RESULT_IMAG:.*]] = arith.divf %[[IMAG_NUM]], %[[DENOM]] fastmath<nnan,contract> : f32
2163+
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
2164+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
2165+
// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
2166+
// CHECK: %[[ABS_REAL_INF:.*]] = arith.cmpf oeq, %[[ABS_REAL]], %[[INF]] fastmath<nnan,contract> : f32
2167+
// CHECK: %[[IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] fastmath<nnan,contract> : f32
2168+
// CHECK: %true = arith.constant true
2169+
// CHECK: %[[ABS_REAL_NOT_INF:.*]] = arith.xori %[[ABS_REAL_INF]], %true : i1
2170+
// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG_NUM]], %[[IMAG_NUM]] fastmath<nnan,contract> : f32
2171+
// CHECK: %[[REAL_IS_NAN:.*]] = arith.andi %[[IMAG_IS_NAN]], %[[ABS_REAL_NOT_INF]] : i1
2172+
// CHECK: %[[AND:.*]] = arith.andi %[[ABS_REAL_INF]], %[[IMAG_IS_NAN]] : i1
2173+
// CHECK: %[[IMAG_IS_NAN2:.*]] = arith.ori %[[IMAG_ZERO]], %[[AND]] : i1
2174+
// CHECK: %[[RESULT_REAL3:.*]] = arith.select %[[REAL_IS_NAN]], %[[NAN]], %[[RESULT_REAL2]] : f32
2175+
// CHECK: %[[RESULT_IMAG2:.*]] = arith.select %[[IMAG_IS_NAN2]], %[[ZERO]], %[[RESULT_IMAG]] : f32
2176+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL3]], %[[RESULT_IMAG2]] : complex<f32>
2177+
// CHECK: return %[[RESULT]] : complex<f32>
2178+
2179+
// -----
2180+
2181+
// CHECK-LABEL: func @complex_tanh_nnan_ninf
2182+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
2183+
func.func @complex_tanh_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
2184+
%tanh = complex.tanh %arg fastmath<nnan,ninf> : complex<f32>
2185+
return %tanh : complex<f32>
2186+
}
2187+
2188+
// CHECK-COUNT-1: arith.select
2189+
// CHECK-NOT: arith.select

0 commit comments

Comments
 (0)