Skip to content

Commit ff3d879

Browse files
committed
fix complex mul
Signed-off-by: Benoit Jacob <[email protected]>
1 parent 03661fb commit ff3d879

File tree

2 files changed

+6
-688
lines changed

2 files changed

+6
-688
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 0 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -696,177 +696,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
696696
auto elementType = cast<FloatType>(type.getElementType());
697697
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698698
auto fmfValue = fmf.getValue();
699-
700699
Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
701-
Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
702700
Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
703-
Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
704701
Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
705-
Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
706702
Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
707-
Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
708-
709703
Value lhsRealTimesRhsReal =
710704
b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
711-
Value lhsRealTimesRhsRealAbs =
712-
b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
713705
Value lhsImagTimesRhsImag =
714706
b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
715-
Value lhsImagTimesRhsImagAbs =
716-
b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
717707
Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
718708
lhsImagTimesRhsImag, fmfValue);
719-
720709
Value lhsImagTimesRhsReal =
721710
b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
722-
Value lhsImagTimesRhsRealAbs =
723-
b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
724711
Value lhsRealTimesRhsImag =
725712
b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
726-
Value lhsRealTimesRhsImagAbs =
727-
b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
728713
Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
729714
lhsRealTimesRhsImag, fmfValue);
730-
731-
// Handle cases where the "naive" calculation results in NaN values.
732-
Value realIsNan =
733-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
734-
Value imagIsNan =
735-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
736-
Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
737-
738-
Value inf = b.create<arith::ConstantOp>(
739-
elementType,
740-
b.getFloatAttr(elementType,
741-
APFloat::getInf(elementType.getFloatSemantics())));
742-
743-
// Case 1. `lhsReal` or `lhsImag` are infinite.
744-
Value lhsRealIsInf =
745-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
746-
Value lhsImagIsInf =
747-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
748-
Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
749-
Value rhsRealIsNan =
750-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
751-
Value rhsImagIsNan =
752-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
753-
Value zero =
754-
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
755-
Value one = b.create<arith::ConstantOp>(elementType,
756-
b.getFloatAttr(elementType, 1));
757-
Value lhsRealIsInfFloat =
758-
b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
759-
lhsReal = b.create<arith::SelectOp>(
760-
lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
761-
lhsReal);
762-
Value lhsImagIsInfFloat =
763-
b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
764-
lhsImag = b.create<arith::SelectOp>(
765-
lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
766-
lhsImag);
767-
Value lhsIsInfAndRhsRealIsNan =
768-
b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
769-
rhsReal = b.create<arith::SelectOp>(
770-
lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
771-
rhsReal);
772-
Value lhsIsInfAndRhsImagIsNan =
773-
b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
774-
rhsImag = b.create<arith::SelectOp>(
775-
lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
776-
rhsImag);
777-
778-
// Case 2. `rhsReal` or `rhsImag` are infinite.
779-
Value rhsRealIsInf =
780-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
781-
Value rhsImagIsInf =
782-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
783-
Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
784-
Value lhsRealIsNan =
785-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
786-
Value lhsImagIsNan =
787-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
788-
Value rhsRealIsInfFloat =
789-
b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
790-
rhsReal = b.create<arith::SelectOp>(
791-
rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
792-
rhsReal);
793-
Value rhsImagIsInfFloat =
794-
b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
795-
rhsImag = b.create<arith::SelectOp>(
796-
rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
797-
rhsImag);
798-
Value rhsIsInfAndLhsRealIsNan =
799-
b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
800-
lhsReal = b.create<arith::SelectOp>(
801-
rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
802-
lhsReal);
803-
Value rhsIsInfAndLhsImagIsNan =
804-
b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
805-
lhsImag = b.create<arith::SelectOp>(
806-
rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
807-
lhsImag);
808-
Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
809-
810-
// Case 3. One of the pairwise products of left hand side with right hand
811-
// side is infinite.
812-
Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
813-
arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
814-
Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
815-
arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
816-
Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
817-
lhsImagTimesRhsImagIsInf);
818-
Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
819-
arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
820-
isSpecialCase =
821-
b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
822-
Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
823-
arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
824-
isSpecialCase =
825-
b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
826-
Type i1Type = b.getI1Type();
827-
Value notRecalc = b.create<arith::XOrIOp>(
828-
recalc,
829-
b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
830-
isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
831-
Value isSpecialCaseAndLhsRealIsNan =
832-
b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
833-
lhsReal = b.create<arith::SelectOp>(
834-
isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
835-
lhsReal);
836-
Value isSpecialCaseAndLhsImagIsNan =
837-
b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
838-
lhsImag = b.create<arith::SelectOp>(
839-
isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
840-
lhsImag);
841-
Value isSpecialCaseAndRhsRealIsNan =
842-
b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
843-
rhsReal = b.create<arith::SelectOp>(
844-
isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
845-
rhsReal);
846-
Value isSpecialCaseAndRhsImagIsNan =
847-
b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
848-
rhsImag = b.create<arith::SelectOp>(
849-
isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
850-
rhsImag);
851-
recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
852-
recalc = b.create<arith::AndIOp>(isNan, recalc);
853-
854-
// Recalculate real part.
855-
lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
856-
lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
857-
Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
858-
lhsImagTimesRhsImag, fmfValue);
859-
real = b.create<arith::SelectOp>(
860-
recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
861-
862-
// Recalculate imag part.
863-
lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
864-
lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
865-
Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
866-
lhsRealTimesRhsImag, fmfValue);
867-
imag = b.create<arith::SelectOp>(
868-
recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
869-
870715
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
871716
return success();
872717
}

0 commit comments

Comments
 (0)