3939#include < iterator>
4040#include < mlir/IR/BuiltinAttributes.h>
4141#include < mlir/IR/Value.h>
42+ #include < optional>
4243#include < set>
4344
4445using namespace mlir ;
@@ -765,19 +766,21 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
765766 }
766767
767768 bool recursiveCheck = false ;
769+ SmallVector<Value> operandsToCheck;
768770
769771 if (isa<stablehlo::SliceOp, stablehlo::ConcatenateOp,
770772 stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
771773 stablehlo::TransposeOp>(op)) {
772774 // data movement ops
773775 recursiveCheck = true ;
776+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
774777 } else if (isa<stablehlo::AbsOp, stablehlo::ExpOp, stablehlo::ConvertOp,
775778 stablehlo::CompareOp, stablehlo::TanhOp, stablehlo::LogisticOp,
776779 stablehlo::FloorOp, stablehlo::CeilOp>(op)) {
777780 // elementwise ops that are no-nan if all operands are not nan
778781 recursiveCheck = true ;
782+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
779783 } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(op)) {
780-
781784 // If any one of the operands is a Inf, the result is Inf. If both are Inf,
782785 // the result is NaN.
783786 auto lhsFinite =
@@ -790,13 +793,15 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
790793 }
791794
792795 recursiveCheck = true ;
796+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
793797 } else if (isa<stablehlo::SineOp, stablehlo::CosineOp>(op)) {
794798
795799 if (!finiteResultAnalysis->guaranteed (op->getOperand (0 ), rewriter)) {
796800 return State::NOTGUARANTEED;
797801 }
798802
799803 recursiveCheck = true ;
804+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
800805 } else if (auto mulOp = dyn_cast<stablehlo::MulOp>(op)) {
801806 // if lhs is Inf & rhs is 0 or the other way around, mul is going to be NaN
802807
@@ -808,38 +813,19 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
808813 }
809814
810815 recursiveCheck = true ;
811- } else if (isa<mlir::stablehlo::SelectOp>(op)) {
816+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
817+ } else if (isa<stablehlo::SelectOp>(op)) {
818+ recursiveCheck = true ;
819+ operandsToCheck.push_back (op->getOperand (1 ));
820+ operandsToCheck.push_back (op->getOperand (2 ));
821+ } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
822+ op)) {
812823 recursiveCheck = true ;
824+ operandsToCheck.push_back (op->getOperand (0 ));
813825 }
814826
815827 if (recursiveCheck) {
816- bool allOperandsGuaranteed = true ;
817- for (auto operand : op->getOperands ()) {
818- if (auto TT = dyn_cast<TensorType>(operand.getType ())) {
819- if (TT.getElementType ().isInteger ())
820- continue ;
821- }
822-
823- {
824- auto found = valueCache.find (operand);
825- if (found != valueCache.end ()) {
826- if (found->second ) {
827- continue ;
828- } else {
829- return State::NOTGUARANTEED;
830- }
831- }
832- }
833-
834- localtodo.push_back (operand);
835- allOperandsGuaranteed = false ;
836- }
837-
838- if (allOperandsGuaranteed) {
839- return State::GUARANTEED;
840- } else {
841- return State::PENDING;
842- }
828+ return recursivelyCheckOperands (localtodo, operandsToCheck, true );
843829 } else {
844830 return State::NOTGUARANTEED;
845831 }
@@ -878,54 +864,36 @@ FiniteResultAnalysis::State FiniteResultAnalysis::localGuaranteed(
878864 }
879865
880866 bool recursiveCheck = false ;
867+ SmallVector<Value> operandsToCheck;
881868
882869 if (isa<stablehlo::SliceOp, stablehlo::ConcatenateOp,
883870 stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
884871 stablehlo::TransposeOp>(op)) {
885872 // data movement ops
886873 recursiveCheck = true ;
874+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
887875 } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp, stablehlo::MulOp,
888876 stablehlo::AbsOp, stablehlo::ExpOp, stablehlo::ConvertOp,
889877 stablehlo::CompareOp>(op)) {
890878 // if both finite [but possibly nan], the result is finite, or nan
891-
892879 recursiveCheck = true ;
880+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
893881 } else if (isa<stablehlo::TanhOp, stablehlo::LogisticOp, stablehlo::SineOp,
894882 stablehlo::CosineOp>(op)) {
895883 // guaranteed finite or nan result, always
896884 return State::GUARANTEED;
897- } else if (isa<mlir:: stablehlo::SelectOp>(op)) {
885+ } else if (isa<stablehlo::SelectOp>(op)) {
898886 recursiveCheck = true ;
887+ operandsToCheck.push_back (op->getOperand (1 ));
888+ operandsToCheck.push_back (op->getOperand (2 ));
889+ } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
890+ op)) {
891+ recursiveCheck = true ;
892+ operandsToCheck.push_back (op->getOperand (0 ));
899893 }
900894
901895 if (recursiveCheck) {
902- bool allOperandsGuaranteed = true ;
903- for (auto operand : op->getOperands ()) {
904- if (auto TT = dyn_cast<TensorType>(operand.getType ())) {
905- if (TT.getElementType ().isInteger ())
906- continue ;
907- }
908-
909- {
910- auto found = valueCache.find (operand);
911- if (found != valueCache.end ()) {
912- if (found->second ) {
913- continue ;
914- } else {
915- return State::NOTGUARANTEED;
916- }
917- }
918- }
919-
920- localtodo.push_back (operand);
921- allOperandsGuaranteed = false ;
922- }
923-
924- if (allOperandsGuaranteed) {
925- return State::GUARANTEED;
926- } else {
927- return State::PENDING;
928- }
896+ return recursivelyCheckOperands (localtodo, operandsToCheck, true );
929897 } else {
930898 return State::NOTGUARANTEED;
931899 }
@@ -1018,45 +986,27 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
1018986 }
1019987
1020988 bool recursiveCheck = false ;
989+ SmallVector<Value> operandsToCheck;
1021990
1022991 if (isa<stablehlo::MinOp, stablehlo::AddOp, stablehlo::MulOp,
1023992 stablehlo::ConcatenateOp, stablehlo::ReshapeOp,
1024993 stablehlo::TransposeOp, stablehlo::SliceOp,
1025- stablehlo::DynamicUpdateSliceOp, stablehlo:: BroadcastInDimOp>(op)) {
994+ stablehlo::BroadcastInDimOp>(op)) {
1026995 // All non-negative operations that produce a non-negative result
1027996 recursiveCheck = true ;
1028- } else if (isa<mlir::stablehlo::SelectOp>(op)) {
997+ operandsToCheck.append (op->getOperands ().begin (), op->getOperands ().end ());
998+ } else if (isa<stablehlo::SelectOp>(op)) {
999+ recursiveCheck = true ;
1000+ operandsToCheck.push_back (op->getOperand (1 ));
1001+ operandsToCheck.push_back (op->getOperand (2 ));
1002+ } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
1003+ op)) {
10291004 recursiveCheck = true ;
1005+ operandsToCheck.push_back (op->getOperand (0 ));
10301006 }
10311007
10321008 if (recursiveCheck) {
1033- bool allOperandsGuaranteed = true ;
1034- size_t idx = 0 ;
1035- for (auto operand : op->getOperands ()) {
1036- if (idx == 0 && isa<mlir::stablehlo::SelectOp>(op))
1037- continue ;
1038- idx++;
1039-
1040- {
1041- auto found = valueCache.find (operand);
1042- if (found != valueCache.end ()) {
1043- if (found->second ) {
1044- continue ;
1045- } else {
1046- return State::NOTGUARANTEED;
1047- }
1048- }
1049- }
1050-
1051- localtodo.push_back (operand);
1052- allOperandsGuaranteed = false ;
1053- }
1054-
1055- if (allOperandsGuaranteed) {
1056- return State::GUARANTEED;
1057- } else {
1058- return State::PENDING;
1059- }
1009+ return recursivelyCheckOperands (localtodo, operandsToCheck, false );
10601010 } else {
10611011 return State::NOTGUARANTEED;
10621012 }
0 commit comments