@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
563563// Operator Folders.
564564// ===----------------------------------------------------------------------===//
565565
566- template <typename IntFolder, typename FloatFolder>
566+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
567567DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
568568 RankedTensorType returnTy) {
569- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
570- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
571- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
572- if (lETy != rETy)
573- return {};
569+ if (!rhs || !lhs)
570+ return {};
571+
572+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
573+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
574+ if (lETy != rETy)
575+ return {};
576+
577+ if (!lETy.isIntOrFloat ())
578+ return {};
574579
580+ if (rhs.isSplat () && lhs.isSplat ()) {
575581 if (llvm::isa<IntegerType>(lETy)) {
576582 APInt l = lhs.getSplatValue <APInt>();
577583 APInt r = rhs.getSplatValue <APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
587593 }
588594 }
589595
596+ if (llvm::isa<IntegerType>(lETy)) {
597+ auto lvalues = lhs.getValues <APInt>();
598+ auto rvalues = rhs.getValues <APInt>();
599+ if (lvalues.size () != rvalues.size ()) {
600+ return {};
601+ }
602+ SmallVector<APInt> results;
603+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
604+ auto result = IntFolder ()(l, r);
605+ results.push_back (result);
606+ }
607+ return DenseElementsAttr::get (returnTy, results);
608+ }
609+
610+ if (llvm::isa<FloatType>(lETy)) {
611+ auto lvalues = lhs.getValues <APFloat>();
612+ auto rvalues = rhs.getValues <APFloat>();
613+ if (lvalues.size () != rvalues.size ()) {
614+ return {};
615+ }
616+ // FloatFolder() may return either APFloat or APInt (comparison functions)
617+ SmallVector<FloatResultAPType> results;
618+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
619+ auto result = FloatFolder ()(l, r);
620+ results.push_back (result);
621+ }
622+ return DenseElementsAttr::get (returnTy, results);
623+ }
624+
590625 return {};
591626}
592627
628+ template <typename IntFolder, typename FloatFolder>
629+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
630+ DenseElementsAttr rhs,
631+ RankedTensorType returnTy) {
632+ // comparison FloatFolder() functions return APInt values
633+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
634+ }
635+
636+ template <typename IntFolder, typename FloatFolder>
637+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
638+ DenseElementsAttr rhs,
639+ RankedTensorType returnTy) {
640+ // arithmetic FloatFolder() functions return APFloat values
641+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
642+ }
643+
593644static bool isSplatZero (Type elemType, DenseElementsAttr val) {
594645 if (llvm::isa<FloatType>(elemType))
595646 return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
636687 if (!lhsAttr || !rhsAttr)
637688 return {};
638689
639- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
640- resultTy);
690+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
691+ lhsAttr, rhsAttr, resultTy);
641692}
642693
643694OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
693744}
694745
695746namespace {
747+
748+ // calculate lhs * rhs >> shift according to TOSA Spec
749+ // return nullopt if result is not in range of int32_t when shift > 0
750+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
751+ unsigned bitwidth) {
752+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
753+
754+ if (shift > 0 ) {
755+ auto round = APInt (64 , 1 ) << (shift - 1 );
756+ result += round;
757+ result.ashrInPlace (shift);
758+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
759+ if (!(result.getSExtValue () >= INT32_MIN &&
760+ result.getSExtValue () <= INT32_MAX)) {
761+ // REQUIRE failed
762+ return std::nullopt ;
763+ }
764+ }
765+
766+ return result.trunc (bitwidth);
767+ }
768+
696769DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
697770 RankedTensorType ty, int32_t shift) {
698- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
699- if (llvm::isa<IntegerType>(ty.getElementType ())) {
700- APInt l = lhs.getSplatValue <APInt>();
701- APInt r = rhs.getSplatValue <APInt>();
771+ if (!lhs || !rhs)
772+ return {};
773+
774+ // REQUIRE(0 <= shift && shift <= 63);
775+ if (!(0 <= shift && shift <= 63 ))
776+ return {};
702777
703- if (shift == 0 ) {
704- return DenseElementsAttr::get (ty, l * r);
778+ auto elementType = ty.getElementType ();
779+ if (!elementType.isIntOrFloat ())
780+ return {};
781+
782+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
783+ // REQUIRE(in_t == int32_t || shift == 0);
784+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
785+ return {};
786+
787+ if (rhs.isSplat () && lhs.isSplat ()) {
788+ if (llvm::isa<IntegerType>(elementType)) {
789+ auto l = lhs.getSplatValue <APInt>();
790+ auto r = rhs.getSplatValue <APInt>();
791+
792+ if (auto result = mulInt (l, r, shift, bitwidth)) {
793+ return DenseElementsAttr::get (ty, result.value ());
705794 }
795+ // mulInt failed
796+ return {};
797+ }
706798
707- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
708- l = l. sext (bitwidth * 2 );
709- r = r. sext (bitwidth * 2 );
799+ if (llvm::isa<FloatType>(elementType)) {
800+ auto l = lhs. getSplatValue <APFloat>( );
801+ auto r = rhs. getSplatValue <APFloat>( );
710802 auto result = l * r;
711- result.lshrInPlace (shift);
712- result = result.trunc (bitwidth);
713803 return DenseElementsAttr::get (ty, result);
714804 }
805+ }
806+
807+ if (llvm::isa<IntegerType>(elementType)) {
808+ auto lvalues = lhs.getValues <APInt>();
809+ auto rvalues = rhs.getValues <APInt>();
810+ if (lvalues.size () != rvalues.size ()) {
811+ return {};
812+ }
813+ SmallVector<APInt> results;
814+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
815+ if (auto result = mulInt (l, r, shift, bitwidth)) {
816+ results.push_back (result.value ());
817+ continue ;
818+ }
819+ // mulInt failed
820+ return {};
821+ }
822+ return DenseElementsAttr::get (ty, results);
823+ }
715824
716- if (llvm::isa<FloatType>(ty. getElementType () )) {
717- APFloat l = lhs.getSplatValue <APFloat>();
718- APFloat r = rhs.getSplatValue <APFloat>();
719- APFloat result = l * r;
720- return DenseElementsAttr::get (ty, result) ;
825+ if (llvm::isa<FloatType>(elementType )) {
826+ auto lvalues = lhs.getValues <APFloat>();
827+ auto rvalues = rhs.getValues <APFloat>();
828+ if (lvalues. size () != rvalues. size ()) {
829+ return {} ;
721830 }
831+ SmallVector<APFloat> results;
832+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
833+ auto result = l * r;
834+ results.push_back (result);
835+ }
836+ return DenseElementsAttr::get (ty, results);
722837 }
723838
724839 return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
793908 if (!lhsAttr || !rhsAttr)
794909 return {};
795910
796- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
797- resultTy);
911+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
912+ lhsAttr, rhsAttr, resultTy);
798913}
799914
800915namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
835950 if (!lhsAttr || !rhsAttr)
836951 return {};
837952
838- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
953+ return comparisonBinaryFolder<APIntFoldGreater,
954+ ComparisonFold<std::greater<APFloat>>>(
839955 lhsAttr, rhsAttr, resultTy);
840956}
841957
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
849965 if (!lhsAttr || !rhsAttr)
850966 return {};
851967
852- return binaryFolder <APIntFoldGreaterEqual,
853- ComparisonFold<std::greater_equal<APFloat>>>(
968+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
969+ ComparisonFold<std::greater_equal<APFloat>>>(
854970 lhsAttr, rhsAttr, resultTy);
855971}
856972
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
874990 if (!lhsAttr || !rhsAttr)
875991 return {};
876992
877- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
878- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
879- resultTy);
993+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
994+ ComparisonFold<std::equal_to<APFloat>>>(
995+ lhsAttr, rhsAttr, resultTy);
880996}
881997
882998OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments