Skip to content

Commit d30bd27

Browse files
[mlir][complex] Fix exp accuracy (#164952)
This ports the implementation from openxla/stablehlo#2682 by @pearu. Three tests were added to `Integration/Dialect/Complex/CPU/correctness.mlir`. I also verified the accuracy using XLA's complex_unary_op_test and its MLIR emitters.
1 parent 566c731 commit d30bd27

File tree

3 files changed

+107
-19
lines changed

3 files changed

+107
-19
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -313,25 +313,53 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
313313
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
314314
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
315315

316+
// exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317+
// Handle special cases as StableHLO implementation does:
318+
// 1. When y == 0, set imag(exp(z)) = 0
319+
// 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
316320
LogicalResult
317321
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
318322
ConversionPatternRewriter &rewriter) const override {
319323
auto loc = op.getLoc();
320324
auto type = cast<ComplexType>(adaptor.getComplex().getType());
321-
auto elementType = cast<FloatType>(type.getElementType());
322-
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
323-
324-
Value real =
325-
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
326-
Value imag =
327-
complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
328-
Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
329-
Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
325+
auto ET = cast<FloatType>(type.getElementType());
326+
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327+
const auto &floatSemantics = ET.getFloatSemantics();
328+
ImplicitLocOpBuilder b(loc, rewriter);
329+
330+
Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331+
Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
332+
Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
333+
Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
334+
Value inf = arith::ConstantOp::create(
335+
b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
336+
337+
Value exp = math::ExpOp::create(b, x, fmf);
338+
Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339+
Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340+
Value cos = math::CosOp::create(b, y, fmf);
341+
Value sin = math::SinOp::create(b, y, fmf);
342+
343+
Value expIsInf =
344+
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345+
Value yIsZero =
346+
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
347+
348+
// Real part: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349+
Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350+
Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351+
Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
330352
Value resultReal =
331-
arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
332-
Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
333-
Value resultImag =
334-
arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
353+
arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
354+
355+
// Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356+
// exp(x/2)*sin(y)*exp(x/2)
357+
Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358+
Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359+
Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
360+
Value imagNonZero =
361+
arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362+
Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
335363

336364
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
337365
resultImag);

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,25 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
211211
}
212212
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
213213
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
214-
// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
214+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
215+
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
216+
// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
215217
// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32
216-
// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
218+
// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] : f32
219+
// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] : f32
220+
// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
217221
// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32
218-
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
222+
// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] : f32
223+
// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
224+
// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
225+
// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] : f32
226+
// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] : f32
227+
// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
228+
// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
229+
// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] : f32
230+
// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] : f32
231+
// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
232+
// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
219233
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
220234
// CHECK: return %[[RESULT]] : complex<f32>
221235

@@ -832,11 +846,25 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
832846
}
833847
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
834848
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
835-
// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
849+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
850+
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
851+
// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
836852
// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
837-
// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
853+
// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] fastmath<nnan,contract> : f32
854+
// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] fastmath<nnan,contract> : f32
855+
// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
838856
// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
839-
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
857+
// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] fastmath<nnan,contract> : f32
858+
// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
859+
// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
860+
// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
861+
// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
862+
// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
863+
// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
864+
// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
865+
// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
866+
// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
867+
// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
840868
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
841869
// CHECK: return %[[RESULT]] : complex<f32>
842870

mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ func.func @conj(%arg: complex<f32>) -> complex<f32> {
4949
func.return %conj : complex<f32>
5050
}
5151

52+
func.func @exp(%arg: complex<f32>) -> complex<f32> {
53+
%exp = complex.exp %arg : complex<f32>
54+
func.return %exp : complex<f32>
55+
}
56+
5257
// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
5358
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
5459
%func: (complex<f32>, complex<f32>) -> complex<f32>) {
@@ -353,5 +358,32 @@ func.func @entry() {
353358
call @test_element_f64(%abs_test_cast, %abs_func)
354359
: (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
355360

361+
// complex.exp test
362+
%exp_test = arith.constant dense<[
363+
(1.0, 2.0),
364+
// CHECK: -1.1312
365+
// CHECK-NEXT: 2.4717
366+
367+
// The first case to consider is overflow of exp(real_part). If computed
368+
// directly, this yields inf * 0 = NaN, which is incorrect.
369+
(500.0, 0.0),
370+
// CHECK-NEXT: inf
371+
// CHECK-NOT: nan
372+
// CHECK-NEXT: 0
373+
374+
// In this case, the overflow of exp(real_part) is compensated when
375+
// sin(imag_part) is close to zero, yielding a finite imaginary part.
376+
(90.0238094, 5.900613e-39)
377+
// CHECK-NEXT: inf
378+
// CHECK-NOT: inf
379+
// CHECK-NEXT: 7.3746
380+
]> : tensor<3xcomplex<f32>>
381+
%exp_test_cast = tensor.cast %exp_test
382+
: tensor<3xcomplex<f32>> to tensor<?xcomplex<f32>>
383+
384+
%exp_func = func.constant @exp : (complex<f32>) -> complex<f32>
385+
call @test_unary(%exp_test_cast, %exp_func)
386+
: (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
387+
356388
func.return
357389
}

0 commit comments

Comments
 (0)