@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196196 auto loc = op.getLoc ();
197197 auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
198198 auto elementType = cast<FloatType>(type.getElementType ());
199+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
199200
200201 Value real =
201202 rewriter.create <complex ::ReOp>(loc, elementType, adaptor.getComplex ());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
207208 // implementation in the subclass to combine them.
208209 Value half = rewriter.create <arith::ConstantOp>(
209210 loc, elementType, rewriter.getFloatAttr (elementType, 0.5 ));
210- Value exp = rewriter.create <math::ExpOp>(loc, imag);
211- Value scaledExp = rewriter.create <arith::MulFOp>(loc, half, exp);
212- Value reciprocalExp = rewriter.create <arith::DivFOp>(loc, half, exp);
213- Value sin = rewriter.create <math::SinOp>(loc, real);
214- Value cos = rewriter.create <math::CosOp>(loc, real);
211+ Value exp = rewriter.create <math::ExpOp>(loc, imag, fmf );
212+ Value scaledExp = rewriter.create <arith::MulFOp>(loc, half, exp, fmf );
213+ Value reciprocalExp = rewriter.create <arith::DivFOp>(loc, half, exp, fmf );
214+ Value sin = rewriter.create <math::SinOp>(loc, real, fmf );
215+ Value cos = rewriter.create <math::CosOp>(loc, real, fmf );
215216
216217 auto resultPair =
217- combine (loc, scaledExp, reciprocalExp, sin, cos, rewriter);
218+ combine (loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf );
218219
219220 rewriter.replaceOpWithNewOp <complex ::CreateOp>(op, type, resultPair.first ,
220221 resultPair.second );
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
223224
224225 virtual std::pair<Value, Value>
225226 combine (Location loc, Value scaledExp, Value reciprocalExp, Value sin,
226- Value cos, ConversionPatternRewriter &rewriter) const = 0 ;
227+ Value cos, ConversionPatternRewriter &rewriter,
228+ arith::FastMathFlagsAttr fmf) const = 0 ;
227229};
228230
229231struct CosOpConversion : public TrigonometricOpConversion <complex ::CosOp> {
230232 using TrigonometricOpConversion<complex ::CosOp>::TrigonometricOpConversion;
231233
232- std::pair<Value, Value>
233- combine (Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234- Value cos, ConversionPatternRewriter &rewriter) const override {
234+ std::pair<Value, Value> combine (Location loc, Value scaledExp,
235+ Value reciprocalExp, Value sin, Value cos,
236+ ConversionPatternRewriter &rewriter,
237+ arith::FastMathFlagsAttr fmf) const override {
235238 // Complex cosine is defined as;
236239 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
237240 // Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
241244 // We get:
242245 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
243246 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
244- Value sum = rewriter.create <arith::AddFOp>(loc, reciprocalExp, scaledExp);
245- Value resultReal = rewriter.create <arith::MulFOp>(loc, sum, cos);
246- Value diff = rewriter.create <arith::SubFOp>(loc, reciprocalExp, scaledExp);
247- Value resultImag = rewriter.create <arith::MulFOp>(loc, diff, sin);
247+ Value sum =
248+ rewriter.create <arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
249+ Value resultReal = rewriter.create <arith::MulFOp>(loc, sum, cos, fmf);
250+ Value diff =
251+ rewriter.create <arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
252+ Value resultImag = rewriter.create <arith::MulFOp>(loc, diff, sin, fmf);
248253 return {resultReal, resultImag};
249254 }
250255};
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
813818struct SinOpConversion : public TrigonometricOpConversion <complex ::SinOp> {
814819 using TrigonometricOpConversion<complex ::SinOp>::TrigonometricOpConversion;
815820
816- std::pair<Value, Value>
817- combine (Location loc, Value scaledExp, Value reciprocalExp, Value sin,
818- Value cos, ConversionPatternRewriter &rewriter) const override {
821+ std::pair<Value, Value> combine (Location loc, Value scaledExp,
822+ Value reciprocalExp, Value sin, Value cos,
823+ ConversionPatternRewriter &rewriter,
824+ arith::FastMathFlagsAttr fmf) const override {
819825 // Complex sine is defined as;
820826 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
821827 // Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
825831 // We get:
826832 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
827833 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
828- Value sum = rewriter.create <arith::AddFOp>(loc, scaledExp, reciprocalExp);
829- Value resultReal = rewriter.create <arith::MulFOp>(loc, sum, sin);
830- Value diff = rewriter.create <arith::SubFOp>(loc, scaledExp, reciprocalExp);
831- Value resultImag = rewriter.create <arith::MulFOp>(loc, diff, cos);
834+ Value sum =
835+ rewriter.create <arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
836+ Value resultReal = rewriter.create <arith::MulFOp>(loc, sum, sin, fmf);
837+ Value diff =
838+ rewriter.create <arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
839+ Value resultImag = rewriter.create <arith::MulFOp>(loc, diff, cos, fmf);
832840 return {resultReal, resultImag};
833841 }
834842};
0 commit comments