@@ -548,15 +548,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
548548// Operator Folders.
549549// ===----------------------------------------------------------------------===//
550550
551- template <typename IntFolder, typename FloatFolder>
551+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
552552DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
553553 RankedTensorType returnTy) {
554- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
555- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
556- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
557- if (lETy != rETy)
558- return {};
554+ if (!rhs || !lhs)
555+ return {};
556+
557+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
558+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
559+ if (lETy != rETy)
560+ return {};
561+
562+ if (!lETy.isIntOrFloat ())
563+ return {};
559564
565+ if (rhs.isSplat () && lhs.isSplat ()) {
560566 if (llvm::isa<IntegerType>(lETy)) {
561567 APInt l = lhs.getSplatValue <APInt>();
562568 APInt r = rhs.getSplatValue <APInt>();
@@ -572,9 +578,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
572578 }
573579 }
574580
581+ auto lhsCount = lhs.getNumElements ();
582+ auto rhsCount = rhs.getNumElements ();
583+ if (lhsCount != rhsCount)
584+ return {};
585+
586+ // to prevent long compile time, skip if too many elements
587+ if (lhsCount > 128 )
588+ return {};
589+
590+ if (llvm::isa<IntegerType>(lETy)) {
591+ auto lvalues = lhs.getValues <APInt>();
592+ auto rvalues = rhs.getValues <APInt>();
593+ SmallVector<APInt> results;
594+ IntFolder intFolder{};
595+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
596+ auto result = intFolder (l, r);
597+ results.push_back (result);
598+ }
599+ return DenseElementsAttr::get (returnTy, results);
600+ }
601+
602+ if (llvm::isa<FloatType>(lETy)) {
603+ auto lvalues = lhs.getValues <APFloat>();
604+ auto rvalues = rhs.getValues <APFloat>();
605+ // FloatFolder() may return either APFloat or APInt (comparison functions)
606+ SmallVector<FloatResultAPType> results;
607+ FloatFolder floatFolder{};
608+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
609+ auto result = floatFolder (l, r);
610+ results.push_back (result);
611+ }
612+ return DenseElementsAttr::get (returnTy, results);
613+ }
614+
575615 return {};
576616}
577617
618+ template <typename IntFolder, typename FloatFolder>
619+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
620+ DenseElementsAttr rhs,
621+ RankedTensorType returnTy) {
622+ // comparison FloatFolder() functions return APInt values
623+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
624+ }
625+
626+ template <typename IntFolder, typename FloatFolder>
627+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
628+ DenseElementsAttr rhs,
629+ RankedTensorType returnTy) {
630+ // arithmetic FloatFolder() functions return APFloat values
631+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
632+ }
633+
578634static bool isSplatZero (Type elemType, DenseElementsAttr val) {
579635 if (llvm::isa<FloatType>(elemType))
580636 return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -621,8 +677,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
621677 if (!lhsAttr || !rhsAttr)
622678 return {};
623679
624- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
625- resultTy);
680+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
681+ lhsAttr, rhsAttr, resultTy);
626682}
627683
628684OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -679,32 +735,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
679735}
680736
681737namespace {
738+
739+ // calculate lhs * rhs >> shift according to TOSA Spec
740+ // return nullopt if result is not in range of int32_t when shift > 0
741+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
742+ unsigned bitwidth) {
743+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
744+
745+ if (shift > 0 ) {
746+ auto round = APInt (64 , 1 ) << (shift - 1 );
747+ result += round;
748+ result.ashrInPlace (shift);
749+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
750+ if (!(result.getSExtValue () >= INT32_MIN &&
751+ result.getSExtValue () <= INT32_MAX)) {
752+ // REQUIRE failed
753+ return std::nullopt ;
754+ }
755+ }
756+
757+ return result.trunc (bitwidth);
758+ }
759+
682760DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
683761 RankedTensorType ty, int32_t shift) {
684- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
685- if (llvm::isa<IntegerType>(ty.getElementType ())) {
686- APInt l = lhs.getSplatValue <APInt>();
687- APInt r = rhs.getSplatValue <APInt>();
762+ if (!lhs || !rhs)
763+ return {};
764+
765+ // REQUIRE(0 <= shift && shift <= 63);
766+ if (!(0 <= shift && shift <= 63 ))
767+ return {};
768+
769+ auto elementType = ty.getElementType ();
770+ if (!elementType.isIntOrFloat ())
771+ return {};
688772
689- if (shift == 0 ) {
690- return DenseElementsAttr::get (ty, l * r);
773+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
774+ // REQUIRE(in_t == int32_t || shift == 0);
775+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
776+ return {};
777+
778+ if (rhs.isSplat () && lhs.isSplat ()) {
779+ if (llvm::isa<IntegerType>(elementType)) {
780+ auto l = lhs.getSplatValue <APInt>();
781+ auto r = rhs.getSplatValue <APInt>();
782+
783+ if (auto result = mulInt (l, r, shift, bitwidth)) {
784+ return DenseElementsAttr::get (ty, result.value ());
691785 }
786+ // mulInt failed
787+ return {};
788+ }
692789
693- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
694- l = l. sext (bitwidth * 2 );
695- r = r. sext (bitwidth * 2 );
790+ if (llvm::isa<FloatType>(elementType)) {
791+ auto l = lhs. getSplatValue <APFloat>( );
792+ auto r = rhs. getSplatValue <APFloat>( );
696793 auto result = l * r;
697- result.lshrInPlace (shift);
698- result = result.trunc (bitwidth);
699794 return DenseElementsAttr::get (ty, result);
700795 }
796+ }
701797
702- if (llvm::isa<FloatType>(ty.getElementType ())) {
703- APFloat l = lhs.getSplatValue <APFloat>();
704- APFloat r = rhs.getSplatValue <APFloat>();
705- APFloat result = l * r;
706- return DenseElementsAttr::get (ty, result);
798+ if (llvm::isa<IntegerType>(elementType)) {
799+ auto lvalues = lhs.getValues <APInt>();
800+ auto rvalues = rhs.getValues <APInt>();
801+ if (lvalues.size () != rvalues.size ()) {
802+ return {};
803+ }
804+ SmallVector<APInt> results;
805+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
806+ if (auto result = mulInt (l, r, shift, bitwidth)) {
807+ results.push_back (result.value ());
808+ continue ;
809+ }
810+ // mulInt failed
811+ return {};
812+ }
813+ return DenseElementsAttr::get (ty, results);
814+ }
815+
816+ if (llvm::isa<FloatType>(elementType)) {
817+ auto lvalues = lhs.getValues <APFloat>();
818+ auto rvalues = rhs.getValues <APFloat>();
819+ if (lvalues.size () != rvalues.size ()) {
820+ return {};
707821 }
822+ SmallVector<APFloat> results;
823+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
824+ auto result = l * r;
825+ results.push_back (result);
826+ }
827+ return DenseElementsAttr::get (ty, results);
708828 }
709829
710830 return {};
@@ -779,8 +899,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
779899 if (!lhsAttr || !rhsAttr)
780900 return {};
781901
782- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
783- resultTy);
902+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
903+ lhsAttr, rhsAttr, resultTy);
784904}
785905
786906namespace {
@@ -821,7 +941,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
821941 if (!lhsAttr || !rhsAttr)
822942 return {};
823943
824- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
944+ return comparisonBinaryFolder<APIntFoldGreater,
945+ ComparisonFold<std::greater<APFloat>>>(
825946 lhsAttr, rhsAttr, resultTy);
826947}
827948
@@ -835,8 +956,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
835956 if (!lhsAttr || !rhsAttr)
836957 return {};
837958
838- return binaryFolder <APIntFoldGreaterEqual,
839- ComparisonFold<std::greater_equal<APFloat>>>(
959+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
960+ ComparisonFold<std::greater_equal<APFloat>>>(
840961 lhsAttr, rhsAttr, resultTy);
841962}
842963
@@ -860,9 +981,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
860981 if (!lhsAttr || !rhsAttr)
861982 return {};
862983
863- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
864- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
865- resultTy);
984+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
985+ ComparisonFold<std::equal_to<APFloat>>>(
986+ lhsAttr, rhsAttr, resultTy);
866987}
867988
868989OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments