Skip to content

Commit 7d2d8e2

Browse files
authored
[mlir][complex] Fastmath flag for the trigonometric ops in complex (#85563)
Propagate the fastmath flags of the trigonometric ops (complex.sin, complex.cos) in the complex dialect to the arith and math ops created when converting them. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent 230b189 commit 7d2d8e2

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196196
auto loc = op.getLoc();
197197
auto type = cast<ComplexType>(adaptor.getComplex().getType());
198198
auto elementType = cast<FloatType>(type.getElementType());
199+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
199200

200201
Value real =
201202
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
207208
// implementation in the subclass to combine them.
208209
Value half = rewriter.create<arith::ConstantOp>(
209210
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
210-
Value exp = rewriter.create<math::ExpOp>(loc, imag);
211-
Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
212-
Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
213-
Value sin = rewriter.create<math::SinOp>(loc, real);
214-
Value cos = rewriter.create<math::CosOp>(loc, real);
211+
Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
212+
Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
213+
Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
214+
Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
215+
Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
215216

216217
auto resultPair =
217-
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
218+
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
218219

219220
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
220221
resultPair.second);
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
223224

224225
virtual std::pair<Value, Value>
225226
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
226-
Value cos, ConversionPatternRewriter &rewriter) const = 0;
227+
Value cos, ConversionPatternRewriter &rewriter,
228+
arith::FastMathFlagsAttr fmf) const = 0;
227229
};
228230

229231
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
230232
using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
231233

232-
std::pair<Value, Value>
233-
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234-
Value cos, ConversionPatternRewriter &rewriter) const override {
234+
std::pair<Value, Value> combine(Location loc, Value scaledExp,
235+
Value reciprocalExp, Value sin, Value cos,
236+
ConversionPatternRewriter &rewriter,
237+
arith::FastMathFlagsAttr fmf) const override {
235238
// Complex cosine is defined as:
236239
// cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
237240
// Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
241244
// We get:
242245
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
243246
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
244-
Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
245-
Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
246-
Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
247-
Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
247+
Value sum =
248+
rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
249+
Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
250+
Value diff =
251+
rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
252+
Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
248253
return {resultReal, resultImag};
249254
}
250255
};
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
813818
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
814819
using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
815820

816-
std::pair<Value, Value>
817-
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
818-
Value cos, ConversionPatternRewriter &rewriter) const override {
821+
std::pair<Value, Value> combine(Location loc, Value scaledExp,
822+
Value reciprocalExp, Value sin, Value cos,
823+
ConversionPatternRewriter &rewriter,
824+
arith::FastMathFlagsAttr fmf) const override {
819825
// Complex sine is defined as:
820826
// sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
821827
// Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
825831
// We get:
826832
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
827833
// Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
828-
Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
829-
Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
830-
Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
831-
Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
834+
Value sum =
835+
rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
836+
Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
837+
Value diff =
838+
rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
839+
Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
832840
return {resultReal, resultImag};
833841
}
834842
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,3 +1834,49 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
18341834
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
18351835
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
18361836
// CHECK: return %[[VAR41]] : complex<f32>
1837+
1838+
// -----
1839+
1840+
// CHECK-LABEL: func @complex_cos_with_fmf
1841+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
1842+
func.func @complex_cos_with_fmf(%arg: complex<f32>) -> complex<f32> {
1843+
%cos = complex.cos %arg fastmath<nnan,contract> : complex<f32>
1844+
return %cos : complex<f32>
1845+
}
1846+
// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
1847+
// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
1848+
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
1849+
// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
1850+
// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
1851+
// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
1852+
// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
1853+
// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
1854+
// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
1855+
// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] fastmath<nnan,contract>
1856+
// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
1857+
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] fastmath<nnan,contract>
1858+
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
1859+
// CHECK: return %[[RESULT]]
1860+
1861+
// -----
1862+
1863+
// CHECK-LABEL: func @complex_sin_with_fmf
1864+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
1865+
func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
1866+
%cos = complex.sin %arg fastmath<nnan,contract> : complex<f32>
1867+
return %cos : complex<f32>
1868+
}
1869+
// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
1870+
// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
1871+
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
1872+
// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
1873+
// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
1874+
// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
1875+
// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
1876+
// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
1877+
// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
1878+
// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] fastmath<nnan,contract>
1879+
// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
1880+
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
1881+
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
1882+
// CHECK: return %[[RESULT]]

0 commit comments

Comments
 (0)