@@ -548,15 +548,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
548548// Operator Folders.
549549// ===----------------------------------------------------------------------===//
550550
551- template <typename IntFolder, typename FloatFolder>
551+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
552552DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
553553 RankedTensorType returnTy) {
554- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
555- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
556- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
557- if (lETy != rETy)
558- return {};
554+ if (!rhs || !lhs)
555+ return {};
556+
557+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
558+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
559+ if (lETy != rETy)
560+ return {};
561+
562+ if (!lETy.isIntOrFloat ())
563+ return {};
559564
565+ if (rhs.isSplat () && lhs.isSplat ()) {
560566 if (llvm::isa<IntegerType>(lETy)) {
561567 APInt l = lhs.getSplatValue <APInt>();
562568 APInt r = rhs.getSplatValue <APInt>();
@@ -572,9 +578,60 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
572578 }
573579 }
574580
581+ auto lhsCount = lhs.getNumElements ();
582+ auto rhsCount = rhs.getNumElements ();
583+ if (lhsCount != rhsCount) {
584+ return {};
585+ }
586+
587+ const int64_t MAX_ELEMENT_COUNT = 128 ;
588+ if (lhsCount > MAX_ELEMENT_COUNT) {
589+ // to prevent long compile time, skip if too many elements
590+ return {};
591+ }
592+
593+ if (llvm::isa<IntegerType>(lETy)) {
594+ auto lvalues = lhs.getValues <APInt>();
595+ auto rvalues = rhs.getValues <APInt>();
596+ SmallVector<APInt> results;
597+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
598+ auto result = IntFolder ()(l, r);
599+ results.push_back (result);
600+ }
601+ return DenseElementsAttr::get (returnTy, results);
602+ }
603+
604+ if (llvm::isa<FloatType>(lETy)) {
605+ auto lvalues = lhs.getValues <APFloat>();
606+ auto rvalues = rhs.getValues <APFloat>();
607+ // FloatFolder() may return either APFloat or APInt (comparison functions)
608+ SmallVector<FloatResultAPType> results;
609+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
610+ auto result = FloatFolder ()(l, r);
611+ results.push_back (result);
612+ }
613+ return DenseElementsAttr::get (returnTy, results);
614+ }
615+
575616 return {};
576617}
577618
619+ template <typename IntFolder, typename FloatFolder>
620+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
621+ DenseElementsAttr rhs,
622+ RankedTensorType returnTy) {
623+ // comparison FloatFolder() functions return APInt values
624+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
625+ }
626+
627+ template <typename IntFolder, typename FloatFolder>
628+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
629+ DenseElementsAttr rhs,
630+ RankedTensorType returnTy) {
631+ // arithmetic FloatFolder() functions return APFloat values
632+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
633+ }
634+
578635static bool isSplatZero (Type elemType, DenseElementsAttr val) {
579636 if (llvm::isa<FloatType>(elemType))
580637 return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -621,8 +678,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
621678 if (!lhsAttr || !rhsAttr)
622679 return {};
623680
624- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
625- resultTy);
681+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
682+ lhsAttr, rhsAttr, resultTy);
626683}
627684
628685OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -679,32 +736,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
679736}
680737
681738namespace {
739+
740+ // calculate lhs * rhs >> shift according to TOSA Spec
741+ // return nullopt if result is not in range of int32_t when shift > 0
742+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
743+ unsigned bitwidth) {
744+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
745+
746+ if (shift > 0 ) {
747+ auto round = APInt (64 , 1 ) << (shift - 1 );
748+ result += round;
749+ result.ashrInPlace (shift);
750+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
751+ if (!(result.getSExtValue () >= INT32_MIN &&
752+ result.getSExtValue () <= INT32_MAX)) {
753+ // REQUIRE failed
754+ return std::nullopt ;
755+ }
756+ }
757+
758+ return result.trunc (bitwidth);
759+ }
760+
682761DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
683762 RankedTensorType ty, int32_t shift) {
684- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
685- if (llvm::isa<IntegerType>(ty.getElementType ())) {
686- APInt l = lhs.getSplatValue <APInt>();
687- APInt r = rhs.getSplatValue <APInt>();
763+ if (!lhs || !rhs)
764+ return {};
765+
766+ // REQUIRE(0 <= shift && shift <= 63);
767+ if (!(0 <= shift && shift <= 63 ))
768+ return {};
688769
689- if (shift == 0 ) {
690- return DenseElementsAttr::get (ty, l * r);
770+ auto elementType = ty.getElementType ();
771+ if (!elementType.isIntOrFloat ())
772+ return {};
773+
774+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
775+ // REQUIRE(in_t == int32_t || shift == 0);
776+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
777+ return {};
778+
779+ if (rhs.isSplat () && lhs.isSplat ()) {
780+ if (llvm::isa<IntegerType>(elementType)) {
781+ auto l = lhs.getSplatValue <APInt>();
782+ auto r = rhs.getSplatValue <APInt>();
783+
784+ if (auto result = mulInt (l, r, shift, bitwidth)) {
785+ return DenseElementsAttr::get (ty, result.value ());
691786 }
787+ // mulInt failed
788+ return {};
789+ }
692790
693- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
694- l = l. sext (bitwidth * 2 );
695- r = r. sext (bitwidth * 2 );
791+ if (llvm::isa<FloatType>(elementType)) {
792+ auto l = lhs. getSplatValue <APFloat>( );
793+ auto r = rhs. getSplatValue <APFloat>( );
696794 auto result = l * r;
697- result.lshrInPlace (shift);
698- result = result.trunc (bitwidth);
699795 return DenseElementsAttr::get (ty, result);
700796 }
797+ }
701798
702- if (llvm::isa<FloatType>(ty. getElementType () )) {
703- APFloat l = lhs.getSplatValue <APFloat >();
704- APFloat r = rhs.getSplatValue <APFloat >();
705- APFloat result = l * r;
706- return DenseElementsAttr::get (ty, result) ;
799+ if (llvm::isa<IntegerType>(elementType )) {
800+ auto lvalues = lhs.getValues <APInt >();
801+ auto rvalues = rhs.getValues <APInt >();
802+ if (lvalues. size () != rvalues. size ()) {
803+ return {} ;
707804 }
805+ SmallVector<APInt> results;
806+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
807+ if (auto result = mulInt (l, r, shift, bitwidth)) {
808+ results.push_back (result.value ());
809+ continue ;
810+ }
811+ // mulInt failed
812+ return {};
813+ }
814+ return DenseElementsAttr::get (ty, results);
815+ }
816+
817+ if (llvm::isa<FloatType>(elementType)) {
818+ auto lvalues = lhs.getValues <APFloat>();
819+ auto rvalues = rhs.getValues <APFloat>();
820+ if (lvalues.size () != rvalues.size ()) {
821+ return {};
822+ }
823+ SmallVector<APFloat> results;
824+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
825+ auto result = l * r;
826+ results.push_back (result);
827+ }
828+ return DenseElementsAttr::get (ty, results);
708829 }
709830
710831 return {};
@@ -779,8 +900,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
779900 if (!lhsAttr || !rhsAttr)
780901 return {};
781902
782- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
783- resultTy);
903+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
904+ lhsAttr, rhsAttr, resultTy);
784905}
785906
786907namespace {
@@ -821,7 +942,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
821942 if (!lhsAttr || !rhsAttr)
822943 return {};
823944
824- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
945+ return comparisonBinaryFolder<APIntFoldGreater,
946+ ComparisonFold<std::greater<APFloat>>>(
825947 lhsAttr, rhsAttr, resultTy);
826948}
827949
@@ -835,8 +957,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
835957 if (!lhsAttr || !rhsAttr)
836958 return {};
837959
838- return binaryFolder <APIntFoldGreaterEqual,
839- ComparisonFold<std::greater_equal<APFloat>>>(
960+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
961+ ComparisonFold<std::greater_equal<APFloat>>>(
840962 lhsAttr, rhsAttr, resultTy);
841963}
842964
@@ -860,9 +982,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
860982 if (!lhsAttr || !rhsAttr)
861983 return {};
862984
863- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
864- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
865- resultTy);
985+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
986+ ComparisonFold<std::equal_to<APFloat>>>(
987+ lhsAttr, rhsAttr, resultTy);
866988}
867989
868990OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments