@@ -584,6 +584,135 @@ bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type outTy, Type inTy,
584584 return allowOnFloatingPointMath || guaranteedNoNanResult (op, rewriter);
585585}
586586
587+ bool canApplySymmetricPattern (mlir::Operation *op, PatternRewriter &rewriter) {
588+ return guaranteedSymmetricResult (op, rewriter);
589+ }
590+
591+ SymmetricResultAnalysis initSymmetricResultAnalysis () {
592+ return SymmetricResultAnalysis ();
593+ }
594+
595+ bool SymmetricResultAnalysis::constantIntCheck (DenseElementsAttr attr) {
596+ return false ; // TODO
597+ }
598+
599+ bool SymmetricResultAnalysis::constantFloatCheck (DenseElementsAttr attr) {
600+ return false ; // TODO
601+ }
602+
603+ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed (
604+ Operation *op, SmallVectorImpl<Operation *> &localtodo,
605+ PatternRewriter &rewriter) {
606+ assert (op);
607+
608+ auto outTy = cast<RankedTensorType>(op->getResult (0 ).getType ());
609+ if (outTy.getRank () != 2 )
610+ return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
611+ if (outTy.getDimSize (0 ) != outTy.getDimSize (1 ))
612+ return State::NOTGUARANTEED; // quick check and exit
613+
614+ SplatElementsAttr splatAttr;
615+ if (matchPattern (op, m_Constant (&splatAttr))) {
616+ return State::GUARANTEED;
617+ }
618+
619+ DenseElementsAttr denseAttr;
620+ if (matchPattern (op, m_Constant (&denseAttr))) {
621+ if (guaranteedConstantOp (op, denseAttr, rewriter)) {
622+ return State::GUARANTEED;
623+ } else {
624+ return State::NOTGUARANTEED;
625+ }
626+ }
627+
628+ // check that transpose dimensions are [1,0]
629+ auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool {
630+ auto perm = tOp.getPermutation ();
631+ return perm.size () == 2 && perm[0 ] == 1 && perm[1 ] == 0 ;
632+ };
633+
634+ // TODO: check for dot_general as well
635+
636+ // commutative operation with A and A^T will always be symmetric
637+ // op(A, A^T) will also always be symmetric
638+ if (stablehlo::hasTraitElementwise (op) &&
639+ (op->hasTrait <OpTrait::IsCommutative>() ||
640+ op->hasTrait <hlo::OpTrait::IsCommutative>())) {
641+ auto lhs = op->getOperand (0 );
642+ auto rhs = op->getOperand (1 );
643+
644+ // op(A, A^T)
645+ if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
646+ if (isTrueTranspose (rhsT)) {
647+ if (lhs == rhsT.getOperand ()) {
648+ return State::GUARANTEED;
649+ }
650+ }
651+ }
652+
653+ // op(A^T, A)
654+ if (auto lhsT = lhs.getDefiningOp <stablehlo::TransposeOp>()) {
655+ if (isTrueTranspose (lhsT)) {
656+ if (rhs == lhsT.getOperand ()) {
657+ return State::GUARANTEED;
658+ }
659+ }
660+ }
661+ }
662+
663+ bool recursiveCheck = false ;
664+
665+ // elementwise ops
666+ if (stablehlo::hasTraitElementwise (op)) {
667+ recursiveCheck = true ;
668+ }
669+
670+ /* *
671+ * TODO
672+ * - check if its * 0 -> symmetric
673+ */
674+
675+ if (recursiveCheck) {
676+ bool allOperandsGuaranteed = true ;
677+ for (auto operand : op->getOperands ()) {
678+ {
679+ auto found = valueCache.find (operand);
680+ if (found != valueCache.end ()) {
681+ if (found->second ) {
682+ continue ;
683+ } else {
684+ return State::NOTGUARANTEED;
685+ }
686+ }
687+ }
688+ auto dop = operand.getDefiningOp ();
689+ if (!dop)
690+ return State::NOTGUARANTEED;
691+
692+ {
693+ auto found = opCache.find (dop);
694+ if (found != opCache.end ()) {
695+ if (found->second ) {
696+ continue ;
697+ } else {
698+ return State::NOTGUARANTEED;
699+ }
700+ }
701+ }
702+
703+ localtodo.push_back (dop);
704+ allOperandsGuaranteed = false ;
705+ }
706+
707+ if (allOperandsGuaranteed)
708+ return State::GUARANTEED;
709+ else
710+ return State::PENDING;
711+ } else {
712+ return State::NOTGUARANTEED;
713+ }
714+ }
715+
587716NoNanResultAnalysis initNoNanResultAnalysis () {
588717 auto finiteAnalysis = std::make_shared<FiniteResultAnalysis>();
589718 auto noNanAnalysis = std::make_shared<NoNanResultAnalysis>();
@@ -617,8 +746,9 @@ NoNanResultAnalysis::localGuaranteed(Operation *op,
617746 return State::NOTGUARANTEED;
618747 }
619748
620- if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
621- if (guaranteed (constantOp, rewriter)) {
749+ DenseElementsAttr denseAttr;
750+ if (matchPattern (op, m_Constant (&denseAttr))) {
751+ if (guaranteedConstantOp (op, denseAttr, rewriter)) {
622752 return State::GUARANTEED;
623753 } else {
624754 return State::NOTGUARANTEED;
@@ -759,8 +889,9 @@ FiniteResultAnalysis::localGuaranteed(Operation *op,
759889 return State::NOTGUARANTEED;
760890 }
761891
762- if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
763- if (guaranteed (constantOp, rewriter)) {
892+ DenseElementsAttr denseAttr;
893+ if (matchPattern (op, m_Constant (&denseAttr))) {
894+ if (guaranteedConstantOp (op, denseAttr, rewriter)) {
764895 return State::GUARANTEED;
765896 } else {
766897 return State::NOTGUARANTEED;
@@ -871,8 +1002,9 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
8711002 return State::NOTGUARANTEED;
8721003 }
8731004
874- if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
875- if (guaranteed (constantOp, rewriter)) {
1005+ DenseElementsAttr denseAttr;
1006+ if (matchPattern (op, m_Constant (&denseAttr))) {
1007+ if (guaranteedConstantOp (op, denseAttr, rewriter)) {
8761008 return State::GUARANTEED;
8771009 } else {
8781010 return State::NOTGUARANTEED;
0 commit comments