@@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
501501// Operator Folders.
502502// ===----------------------------------------------------------------------===//
503503
504- template <typename IntFolder, typename FloatFolder>
504+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
505505DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
506506 RankedTensorType returnTy) {
507- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
508- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
509- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
510- if (lETy != rETy)
511- return {};
507+ if (!rhs || !lhs)
508+ return {};
509+
510+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
511+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
512+ if (lETy != rETy)
513+ return {};
514+
515+ if (!lETy.isIntOrFloat ())
516+ return {};
512517
518+ if (rhs.isSplat () && lhs.isSplat ()) {
513519 if (llvm::isa<IntegerType>(lETy)) {
514520 APInt l = lhs.getSplatValue <APInt>();
515521 APInt r = rhs.getSplatValue <APInt>();
@@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
525531 }
526532 }
527533
534+ auto lhsCount = lhs.getNumElements ();
535+ auto rhsCount = rhs.getNumElements ();
536+ if (lhsCount != rhsCount)
537+ return {};
538+
539+ // to prevent long compile time, skip if too many elements
540+ if (lhsCount > 128 )
541+ return {};
542+
543+ if (llvm::isa<IntegerType>(lETy)) {
544+ auto lvalues = lhs.getValues <APInt>();
545+ auto rvalues = rhs.getValues <APInt>();
546+ SmallVector<APInt> results;
547+ IntFolder intFolder{};
548+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
549+ auto result = intFolder (l, r);
550+ results.push_back (result);
551+ }
552+ return DenseElementsAttr::get (returnTy, results);
553+ }
554+
555+ if (llvm::isa<FloatType>(lETy)) {
556+ auto lvalues = lhs.getValues <APFloat>();
557+ auto rvalues = rhs.getValues <APFloat>();
558+ // FloatFolder() may return either APFloat or APInt (comparison functions)
559+ SmallVector<FloatResultAPType> results;
560+ FloatFolder floatFolder{};
561+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
562+ auto result = floatFolder (l, r);
563+ results.push_back (result);
564+ }
565+ return DenseElementsAttr::get (returnTy, results);
566+ }
567+
528568 return {};
529569}
530570
571+ template <typename IntFolder, typename FloatFolder>
572+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
573+ DenseElementsAttr rhs,
574+ RankedTensorType returnTy) {
575+ // comparison FloatFolder() functions return APInt values
576+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
577+ }
578+
579+ template <typename IntFolder, typename FloatFolder>
580+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
581+ DenseElementsAttr rhs,
582+ RankedTensorType returnTy) {
583+ // arithmetic FloatFolder() functions return APFloat values
584+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
585+ }
586+
531587static bool isSplatZero (Type elemType, DenseElementsAttr val) {
532588 if (llvm::isa<FloatType>(elemType))
533589 return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
574630 if (!lhsAttr || !rhsAttr)
575631 return {};
576632
577- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
578- resultTy);
633+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
634+ lhsAttr, rhsAttr, resultTy);
579635}
580636
581637OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
632688}
633689
634690namespace {
691+
692+ // calculate lhs * rhs >> shift according to TOSA Spec
693+ // return nullopt if result is not in range of int32_t when shift > 0
694+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
695+ unsigned bitwidth) {
696+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
697+
698+ if (shift > 0 ) {
699+ auto round = APInt (64 , 1 ) << (shift - 1 );
700+ result += round;
701+ result.ashrInPlace (shift);
702+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
703+ if (!(result.getSExtValue () >= INT32_MIN &&
704+ result.getSExtValue () <= INT32_MAX)) {
705+ // REQUIRE failed
706+ return std::nullopt ;
707+ }
708+ }
709+
710+ return result.trunc (bitwidth);
711+ }
712+
635713DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
636714 RankedTensorType ty, int32_t shift) {
637- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
638- if (llvm::isa<IntegerType>(ty.getElementType ())) {
639- APInt l = lhs.getSplatValue <APInt>();
640- APInt r = rhs.getSplatValue <APInt>();
715+ if (!lhs || !rhs)
716+ return {};
717+
718+ // REQUIRE(0 <= shift && shift <= 63);
719+ if (!(0 <= shift && shift <= 63 ))
720+ return {};
721+
722+ auto elementType = ty.getElementType ();
723+ if (!elementType.isIntOrFloat ())
724+ return {};
641725
642- if (shift == 0 ) {
643- return DenseElementsAttr::get (ty, l * r);
726+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
727+ // REQUIRE(in_t == int32_t || shift == 0);
728+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
729+ return {};
730+
731+ if (rhs.isSplat () && lhs.isSplat ()) {
732+ if (llvm::isa<IntegerType>(elementType)) {
733+ auto l = lhs.getSplatValue <APInt>();
734+ auto r = rhs.getSplatValue <APInt>();
735+
736+ if (auto result = mulInt (l, r, shift, bitwidth)) {
737+ return DenseElementsAttr::get (ty, result.value ());
644738 }
739+ // mulInt failed
740+ return {};
741+ }
645742
646- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
647- l = l. sext (bitwidth * 2 );
648- r = r. sext (bitwidth * 2 );
743+ if (llvm::isa<FloatType>(elementType)) {
744+ auto l = lhs. getSplatValue <APFloat>( );
745+ auto r = rhs. getSplatValue <APFloat>( );
649746 auto result = l * r;
650- result.lshrInPlace (shift);
651- result = result.trunc (bitwidth);
652747 return DenseElementsAttr::get (ty, result);
653748 }
749+ }
654750
655- if (llvm::isa<FloatType>(ty.getElementType ())) {
656- APFloat l = lhs.getSplatValue <APFloat>();
657- APFloat r = rhs.getSplatValue <APFloat>();
658- APFloat result = l * r;
659- return DenseElementsAttr::get (ty, result);
751+ if (llvm::isa<IntegerType>(elementType)) {
752+ auto lvalues = lhs.getValues <APInt>();
753+ auto rvalues = rhs.getValues <APInt>();
754+ if (lvalues.size () != rvalues.size ()) {
755+ return {};
756+ }
757+ SmallVector<APInt> results;
758+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
759+ if (auto result = mulInt (l, r, shift, bitwidth)) {
760+ results.push_back (result.value ());
761+ continue ;
762+ }
763+ // mulInt failed
764+ return {};
765+ }
766+ return DenseElementsAttr::get (ty, results);
767+ }
768+
769+ if (llvm::isa<FloatType>(elementType)) {
770+ auto lvalues = lhs.getValues <APFloat>();
771+ auto rvalues = rhs.getValues <APFloat>();
772+ if (lvalues.size () != rvalues.size ()) {
773+ return {};
660774 }
775+ SmallVector<APFloat> results;
776+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
777+ auto result = l * r;
778+ results.push_back (result);
779+ }
780+ return DenseElementsAttr::get (ty, results);
661781 }
662782
663783 return {};
@@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
732852 if (!lhsAttr || !rhsAttr)
733853 return {};
734854
735- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
736- resultTy);
855+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
856+ lhsAttr, rhsAttr, resultTy);
737857}
738858
739859namespace {
@@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
774894 if (!lhsAttr || !rhsAttr)
775895 return {};
776896
777- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
897+ return comparisonBinaryFolder<APIntFoldGreater,
898+ ComparisonFold<std::greater<APFloat>>>(
778899 lhsAttr, rhsAttr, resultTy);
779900}
780901
@@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
788909 if (!lhsAttr || !rhsAttr)
789910 return {};
790911
791- return binaryFolder <APIntFoldGreaterEqual,
792- ComparisonFold<std::greater_equal<APFloat>>>(
912+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
913+ ComparisonFold<std::greater_equal<APFloat>>>(
793914 lhsAttr, rhsAttr, resultTy);
794915}
795916
@@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
813934 if (!lhsAttr || !rhsAttr)
814935 return {};
815936
816- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
817- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
818- resultTy);
937+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
938+ ComparisonFold<std::equal_to<APFloat>>>(
939+ lhsAttr, rhsAttr, resultTy);
819940}
820941
821942OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments