diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 1959d5f3a5899..e04621f71f9a7 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr &expr) { return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0; } +// Checks whether the symbol on the LHS is present in the RHS expression. +bool CheckForSymbolMatch(const Expr *lhs, const Expr *rhs); + +namespace operation { + +enum class Operator { + Unknown, + Add, + And, + Associated, + Call, + Constant, + Convert, + Div, + Eq, + Eqv, + False, + Ge, + Gt, + Identity, + Intrinsic, + Le, + Lt, + Max, + Min, + Mul, + Ne, + Neqv, + Not, + Or, + Pow, + Resize, // Convert within the same TypeCategory + Sub, + True, +}; + +std::string ToString(Operator op); + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + switch (op.derived().logicalOperator) { + case common::LogicalOperator::And: + return Operator::And; + case common::LogicalOperator::Or: + return Operator::Or; + case common::LogicalOperator::Eqv: + return Operator::Eqv; + case common::LogicalOperator::Neqv: + return Operator::Neqv; + case common::LogicalOperator::Not: + return Operator::Not; + } + return Operator::Unknown; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + switch (op.derived().opr) { + case common::RelationalOperator::LT: + return Operator::Lt; + case common::RelationalOperator::LE: + return Operator::Le; + case common::RelationalOperator::EQ: + return Operator::Eq; + case common::RelationalOperator::NE: + return Operator::Ne; + case common::RelationalOperator::GE: + return Operator::Ge; + case common::RelationalOperator::GT: + return Operator::Gt; + } + return Operator::Unknown; +} + +template +Operator OperationCode(const evaluate::Operation, Ts...> &op) { + return Operator::Add; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + return Operator::Sub; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + return Operator::Mul; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + return Operator::Div; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + return Operator::Pow; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + return Operator::Pow; +} + +template +Operator OperationCode( + const evaluate::Operation, Ts...> &op) { + if constexpr (C == T::category) { + return Operator::Resize; + } else { + return Operator::Convert; + } +} + +template Operator OperationCode(const evaluate::Constant &x) { + return Operator::Constant; +} + +template Operator OperationCode(const T &) { + return Operator::Unknown; +} + +Operator OperationCode(const evaluate::ProcedureDesignator &proc); + +} // namespace operation + +// Return information about the top-level operation (ignoring parentheses): +// the operation code and the list of arguments. +std::pair>> +GetTopLevelOperation(const Expr &expr); + +// Check if expr is same as x, or a sequence of Convert operations on x. +bool IsSameOrConvertOf(const Expr &expr, const Expr &x); + +// Strip away any top-level Convert operations (if any exist) and return +// the input value. A ComplexConstructor(x, 0) is also considered as a +// convert operation. +// If the input is not Operation, Designator, FunctionRef or Constant, +// it returns std::nullopt. +std::optional> GetConvertInput(const Expr &x); + } // namespace Fortran::evaluate namespace Fortran::semantics { diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h index 69375a83dec25..f3cfa9b99fb4d 100644 --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring); // Check for ambiguous USE associations bool HadUseError(SemanticsContext &, SourceName at, const Symbol *); -// Checks whether the symbol on the LHS is present in the RHS expression. -bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs); - -namespace operation { - -enum class Operator { - Unknown, - Add, - And, - Associated, - Call, - Constant, - Convert, - Div, - Eq, - Eqv, - False, - Ge, - Gt, - Identity, - Intrinsic, - Le, - Lt, - Max, - Min, - Mul, - Ne, - Neqv, - Not, - Or, - Pow, - Resize, // Convert within the same TypeCategory - Sub, - True, -}; - -std::string ToString(Operator op); - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - switch (op.derived().logicalOperator) { - case common::LogicalOperator::And: - return Operator::And; - case common::LogicalOperator::Or: - return Operator::Or; - case common::LogicalOperator::Eqv: - return Operator::Eqv; - case common::LogicalOperator::Neqv: - return Operator::Neqv; - case common::LogicalOperator::Not: - return Operator::Not; - } - return Operator::Unknown; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - switch (op.derived().opr) { - case common::RelationalOperator::LT: - return Operator::Lt; - case common::RelationalOperator::LE: - return Operator::Le; - case common::RelationalOperator::EQ: - return Operator::Eq; - case common::RelationalOperator::NE: - return Operator::Ne; - case common::RelationalOperator::GE: - return Operator::Ge; - case common::RelationalOperator::GT: - return Operator::Gt; - } - return Operator::Unknown; -} - -template -Operator OperationCode(const evaluate::Operation, Ts...> &op) { - return Operator::Add; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - return Operator::Sub; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - return Operator::Mul; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - return Operator::Div; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - return Operator::Pow; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - return Operator::Pow; -} - -template -Operator OperationCode( - const evaluate::Operation, Ts...> &op) { - if constexpr (C == T::category) { - return Operator::Resize; - } else { - return Operator::Convert; - } -} - -template // -Operator OperationCode(const evaluate::Constant &x) { - return Operator::Constant; -} - -template // -Operator OperationCode(const T &) { - return Operator::Unknown; -} - -Operator OperationCode(const evaluate::ProcedureDesignator &proc); - -} // namespace operation - -/// Return information about the top-level operation (ignoring parentheses): -/// the operation code and the list of arguments. -std::pair> GetTopLevelOperation( - const SomeExpr &expr); - -/// Check if expr is same as x, or a sequence of Convert operations on x. -bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x); - -/// Strip away any top-level Convert operations (if any exist) and return -/// the input value. A ComplexConstructor(x, 0) is also considered as a -/// convert operation. -/// If the input is not Operation, Designator, FunctionRef or Constant, -/// it returns std::nullopt. -MaybeExpr GetConvertInput(const SomeExpr &x); } // namespace Fortran::semantics #endif // FORTRAN_SEMANTICS_TOOLS_H_ diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index 222c32a9c332e..68838564f87ba 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -13,6 +13,7 @@ #include "flang/Evaluate/traverse.h" #include "flang/Parser/message.h" #include "flang/Semantics/tools.h" +#include "llvm/ADT/StringSwitch.h" #include #include @@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages, } } +bool CheckForSymbolMatch(const Expr *lhs, const Expr *rhs) { + if (lhs && rhs) { + if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) { + const Symbol &first{*lhsSymbols.front()}; + for (const Symbol &symbol : GetSymbolVector(*rhs)) { + if (first == symbol) { + return true; + } + } + } + } + return false; +} + +namespace operation { +template Expr AsSomeExpr(const T &x) { + return AsGenericExpr(common::Clone(x)); +} + +template +struct ArgumentExtractor + : public Traverse, + std::pair>>, false> { + using Arguments = std::vector>; + using Result = std::pair; + using Base = + Traverse, Result, false>; + static constexpr auto IgnoreResizes{IgnoreResizingConverts}; + static constexpr auto Logical{common::TypeCategory::Logical}; + ArgumentExtractor() : Base(*this) {} + + Result Default() const { return {}; } + + using Base::operator(); + + template + Result operator()(const Constant> &x) const { + if (const auto &val{x.GetScalarValue()}) { + return val->IsTrue() + ? std::make_pair(operation::Operator::True, Arguments{}) + : std::make_pair(operation::Operator::False, Arguments{}); + } + return Default(); + } + + template Result operator()(const FunctionRef &x) const { + Result result{operation::OperationCode(x.proc()), {}}; + for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) { + if (auto *e{x.UnwrapArgExpr(i)}) { + result.second.push_back(*e); + } + } + return result; + } + + template + Result operator()(const Operation &x) const { + if constexpr (std::is_same_v>) { + // Ignore top-level parentheses. + return (*this)(x.template operand<0>()); + } + if constexpr (IgnoreResizes && std::is_same_v>) { + // Ignore conversions within the same category. + // Atomic operations on int(kind=1) may be implicitly widened + // to int(kind=4) for example. + return (*this)(x.template operand<0>()); + } else { + return std::make_pair(operation::OperationCode(x), + OperationArgs(x, std::index_sequence_for{})); + } + } + + template Result operator()(const Designator &x) const { + return {operation::Operator::Identity, {AsSomeExpr(x)}}; + } + + template Result operator()(const Constant &x) const { + return {operation::Operator::Identity, {AsSomeExpr(x)}}; + } + + template + Result Combine(Result &&result, Rs &&...results) const { + // There shouldn't be any combining needed, since we're stopping the + // traversal at the top-level operation, but implement one that picks + // the first non-empty result. + if constexpr (sizeof...(Rs) == 0) { + return std::move(result); + } else { + if (!result.second.empty()) { + return std::move(result); + } else { + return Combine(std::move(results)...); + } + } + } + +private: + template + Arguments OperationArgs( + const Operation &x, std::index_sequence) const { + return Arguments{Expr(x.template operand())...}; + } +}; +} // namespace operation + +std::string operation::ToString(operation::Operator op) { + switch (op) { + case Operator::Unknown: + return "??"; + case Operator::Add: + return "+"; + case Operator::And: + return "AND"; + case Operator::Associated: + return "ASSOCIATED"; + case Operator::Call: + return "function-call"; + case Operator::Constant: + return "constant"; + case Operator::Convert: + return "type-conversion"; + case Operator::Div: + return "/"; + case Operator::Eq: + return "=="; + case Operator::Eqv: + return "EQV"; + case Operator::False: + return ".FALSE."; + case Operator::Ge: + return ">="; + case Operator::Gt: + return ">"; + case Operator::Identity: + return "identity"; + case Operator::Intrinsic: + return "intrinsic"; + case Operator::Le: + return "<="; + case Operator::Lt: + return "<"; + case Operator::Max: + return "MAX"; + case Operator::Min: + return "MIN"; + case Operator::Mul: + return "*"; + case Operator::Ne: + return "/="; + case Operator::Neqv: + return "NEQV/EOR"; + case Operator::Not: + return "NOT"; + case Operator::Or: + return "OR"; + case Operator::Pow: + return "**"; + case Operator::Resize: + return "resize"; + case Operator::Sub: + return "-"; + case Operator::True: + return ".TRUE."; + } + llvm_unreachable("Unhandler operator"); +} + +operation::Operator operation::OperationCode(const ProcedureDesignator &proc) { + Operator code{llvm::StringSwitch(proc.GetName()) + .Case("associated", Operator::Associated) + .Case("min", Operator::Min) + .Case("max", Operator::Max) + .Case("iand", Operator::And) + .Case("ior", Operator::Or) + .Case("ieor", Operator::Neqv) + .Default(Operator::Call)}; + if (code == Operator::Call && proc.GetSpecificIntrinsic()) { + return Operator::Intrinsic; + } + return code; +} + +std::pair>> +GetTopLevelOperation(const Expr &expr) { + return operation::ArgumentExtractor{}(expr); +} + +namespace operation { +struct ConvertCollector + : public Traverse>, std::vector>, + false> { + using Result = + std::pair>, std::vector>; + using Base = Traverse; + ConvertCollector() : Base(*this) {} + + Result Default() const { return {}; } + + using Base::operator(); + + template Result operator()(const Designator &x) const { + return {AsSomeExpr(x), {}}; + } + + template Result operator()(const FunctionRef &x) const { + return {AsSomeExpr(x), {}}; + } + + template Result operator()(const Constant &x) const { + return {AsSomeExpr(x), {}}; + } + + template + Result operator()(const Operation &x) const { + if constexpr (std::is_same_v>) { + // Ignore parentheses. + return (*this)(x.template operand<0>()); + } else if constexpr (is_convert_v) { + // Convert should always have a typed result, so it should be safe to + // dereference x.GetType(). + return Combine( + {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>())); + } else if constexpr (is_complex_constructor_v) { + // This is a conversion iff the imaginary operand is 0. + if (IsZero(x.template operand<1>())) { + return Combine( + {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>())); + } else { + return {AsSomeExpr(x.derived()), {}}; + } + } else { + return {AsSomeExpr(x.derived()), {}}; + } + } + + template + Result Combine(Result &&result, Rs &&...results) const { + Result v(std::move(result)); + auto setValue{[](std::optional> &x, + std::optional> &&y) { + assert((!x.has_value() || !y.has_value()) && "Multiple designators"); + if (!x.has_value()) { + x = std::move(y); + } + }}; + auto moveAppend{[](auto &accum, auto &&other) { + for (auto &&s : other) { + accum.push_back(std::move(s)); + } + }}; + (setValue(v.first, std::move(results).first), ...); + (moveAppend(v.second, std::move(results).second), ...); + return v; + } + +private: + template static bool IsZero(const A &x) { return false; } + template static bool IsZero(const Expr &x) { + return common::visit([](auto &&s) { return IsZero(s); }, x.u); + } + template static bool IsZero(const Constant &x) { + if (auto &&maybeScalar{x.GetScalarValue()}) { + return maybeScalar->IsZero(); + } else { + return false; + } + } + + template struct is_convert { + static constexpr bool value{false}; + }; + template + struct is_convert> { + static constexpr bool value{true}; + }; + template struct is_convert> { + // Conversion from complex to real. + static constexpr bool value{true}; + }; + template + static constexpr bool is_convert_v{is_convert::value}; + + template struct is_complex_constructor { + static constexpr bool value{false}; + }; + template struct is_complex_constructor> { + static constexpr bool value{true}; + }; + template + static constexpr bool is_complex_constructor_v{ + is_complex_constructor::value}; +}; +} // namespace operation + +std::optional> GetConvertInput(const Expr &x) { + // This returns Expr{x} when x is a designator/functionref/constant. + return operation::ConvertCollector{}(x).first; +} + +bool IsSameOrConvertOf(const Expr &expr, const Expr &x) { + // Check if expr is same as x, or a sequence of Convert operations on x. + if (expr == x) { + return true; + } else if (auto maybe{GetConvertInput(expr)}) { + return *maybe == x; + } else { + return false; + } +} } // namespace Fortran::evaluate namespace Fortran::semantics { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 69e9c53baa740..3ef3330cba2d6 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter, mlir::Block &block = atomicCaptureOp->getRegion(0).back(); firOpBuilder.setInsertionPointToStart(&block); if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) { - if (Fortran::semantics::CheckForSymbolMatch( + if (Fortran::evaluate::CheckForSymbolMatch( Fortran::semantics::GetExpr(stmt2Var), Fortran::semantics::GetExpr(stmt2Expr))) { // Atomic capture construct is of the form [capture-stmt, update-stmt] diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 82673f0948a5b..0acfd5b0a2534 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc, mlir::Type atomType = fir::unwrapRefType(atomAddr.getType()); // This must exist by now. - SomeExpr input = *semantics::GetConvertInput(assign.rhs); - std::vector args{semantics::GetTopLevelOperation(input).second}; + SomeExpr input = *Fortran::evaluate::GetConvertInput(assign.rhs); + std::vector args{ + Fortran::evaluate::GetTopLevelOperation(input).second}; assert(!args.empty() && "Update operation without arguments"); for (auto &arg : args) { - if (!semantics::IsSameOrConvertOf(arg, atom)) { + if (!Fortran::evaluate::IsSameOrConvertOf(arg, atom)) { mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc)); overrides.try_emplace(&arg, val); } diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index 58d28dce7094a..47bd4e8ffd43c 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -12,6 +12,7 @@ #include "flang/Evaluate/check-expression.h" #include "flang/Evaluate/expression.h" #include "flang/Evaluate/shape.h" +#include "flang/Evaluate/tools.h" #include "flang/Evaluate/type.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" @@ -2962,6 +2963,8 @@ static bool IsPointerAssignment(const evaluate::Assignment &x) { std::holds_alternative(x.u); } +namespace operation = Fortran::evaluate::operation; + static bool IsCheckForAssociated(const SomeExpr &cond) { return GetTopLevelOperation(cond).first == operation::Operator::Associated; } diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index bf520d04a50cc..d053179448c00 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -17,7 +17,6 @@ #include "flang/Semantics/tools.h" #include "flang/Semantics/type.h" #include "flang/Support/Fortran.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -1789,332 +1788,4 @@ bool HadUseError( } } -bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs) { - if (lhs && rhs) { - if (SymbolVector lhsSymbols{evaluate::GetSymbolVector(*lhs)}; - !lhsSymbols.empty()) { - const Symbol &first{*lhsSymbols.front()}; - for (const Symbol &symbol : evaluate::GetSymbolVector(*rhs)) { - if (first == symbol) { - return true; - } - } - } - } - return false; -} - -namespace operation { -template // -SomeExpr asSomeExpr(const T &x) { - auto copy{x}; - return AsGenericExpr(std::move(copy)); -} - -template // -struct ArgumentExtractor - : public evaluate::Traverse, - std::pair>, false> { - using Arguments = std::vector; - using Result = std::pair; - using Base = evaluate::Traverse, - Result, false>; - static constexpr auto IgnoreResizes = IgnoreResizingConverts; - static constexpr auto Logical = common::TypeCategory::Logical; - ArgumentExtractor() : Base(*this) {} - - Result Default() const { return {}; } - - using Base::operator(); - - template // - Result operator()( - const evaluate::Constant> &x) const { - if (const auto &val{x.GetScalarValue()}) { - return val->IsTrue() - ? std::make_pair(operation::Operator::True, Arguments{}) - : std::make_pair(operation::Operator::False, Arguments{}); - } - return Default(); - } - - template // - Result operator()(const evaluate::FunctionRef &x) const { - Result result{operation::OperationCode(x.proc()), {}}; - for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) { - if (auto *e{x.UnwrapArgExpr(i)}) { - result.second.push_back(*e); - } - } - return result; - } - - template - Result operator()(const evaluate::Operation &x) const { - if constexpr (std::is_same_v>) { - // Ignore top-level parentheses. - return (*this)(x.template operand<0>()); - } - if constexpr (IgnoreResizes && - std::is_same_v>) { - // Ignore conversions within the same category. - // Atomic operations on int(kind=1) may be implicitly widened - // to int(kind=4) for example. - return (*this)(x.template operand<0>()); - } else { - return std::make_pair(operation::OperationCode(x), - OperationArgs(x, std::index_sequence_for{})); - } - } - - template // - Result operator()(const evaluate::Designator &x) const { - return {operation::Operator::Identity, {asSomeExpr(x)}}; - } - - template // - Result operator()(const evaluate::Constant &x) const { - return {operation::Operator::Identity, {asSomeExpr(x)}}; - } - - template // - Result Combine(Result &&result, Rs &&...results) const { - // There shouldn't be any combining needed, since we're stopping the - // traversal at the top-level operation, but implement one that picks - // the first non-empty result. - if constexpr (sizeof...(Rs) == 0) { - return std::move(result); - } else { - if (!result.second.empty()) { - return std::move(result); - } else { - return Combine(std::move(results)...); - } - } - } - -private: - template - Arguments OperationArgs(const evaluate::Operation &x, - std::index_sequence) const { - return Arguments{SomeExpr(x.template operand())...}; - } -}; -} // namespace operation - -std::string operation::ToString(operation::Operator op) { - switch (op) { - case Operator::Unknown: - return "??"; - case Operator::Add: - return "+"; - case Operator::And: - return "AND"; - case Operator::Associated: - return "ASSOCIATED"; - case Operator::Call: - return "function-call"; - case Operator::Constant: - return "constant"; - case Operator::Convert: - return "type-conversion"; - case Operator::Div: - return "/"; - case Operator::Eq: - return "=="; - case Operator::Eqv: - return "EQV"; - case Operator::False: - return ".FALSE."; - case Operator::Ge: - return ">="; - case Operator::Gt: - return ">"; - case Operator::Identity: - return "identity"; - case Operator::Intrinsic: - return "intrinsic"; - case Operator::Le: - return "<="; - case Operator::Lt: - return "<"; - case Operator::Max: - return "MAX"; - case Operator::Min: - return "MIN"; - case Operator::Mul: - return "*"; - case Operator::Ne: - return "/="; - case Operator::Neqv: - return "NEQV/EOR"; - case Operator::Not: - return "NOT"; - case Operator::Or: - return "OR"; - case Operator::Pow: - return "**"; - case Operator::Resize: - return "resize"; - case Operator::Sub: - return "-"; - case Operator::True: - return ".TRUE."; - } - llvm_unreachable("Unhandler operator"); -} - -operation::Operator operation::OperationCode( - const evaluate::ProcedureDesignator &proc) { - Operator code = llvm::StringSwitch(proc.GetName()) - .Case("associated", Operator::Associated) - .Case("min", Operator::Min) - .Case("max", Operator::Max) - .Case("iand", Operator::And) - .Case("ior", Operator::Or) - .Case("ieor", Operator::Neqv) - .Default(Operator::Call); - if (code == Operator::Call && proc.GetSpecificIntrinsic()) { - return Operator::Intrinsic; - } - return code; -} - -std::pair> GetTopLevelOperation( - const SomeExpr &expr) { - return operation::ArgumentExtractor{}(expr); -} - -namespace operation { -struct ConvertCollector - : public evaluate::Traverse>, false> { - using Result = std::pair>; - using Base = evaluate::Traverse; - ConvertCollector() : Base(*this) {} - - Result Default() const { return {}; } - - using Base::operator(); - - template // - Result operator()(const evaluate::Designator &x) const { - return {asSomeExpr(x), {}}; - } - - template // - Result operator()(const evaluate::FunctionRef &x) const { - return {asSomeExpr(x), {}}; - } - - template // - Result operator()(const evaluate::Constant &x) const { - return {asSomeExpr(x), {}}; - } - - template - Result operator()(const evaluate::Operation &x) const { - if constexpr (std::is_same_v>) { - // Ignore parentheses. - return (*this)(x.template operand<0>()); - } else if constexpr (is_convert_v) { - // Convert should always have a typed result, so it should be safe to - // dereference x.GetType(). - return Combine( - {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>())); - } else if constexpr (is_complex_constructor_v) { - // This is a conversion iff the imaginary operand is 0. - if (IsZero(x.template operand<1>())) { - return Combine( - {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>())); - } else { - return {asSomeExpr(x.derived()), {}}; - } - } else { - return {asSomeExpr(x.derived()), {}}; - } - } - - template // - Result Combine(Result &&result, Rs &&...results) const { - Result v(std::move(result)); - auto setValue{[](MaybeExpr &x, MaybeExpr &&y) { - assert((!x.has_value() || !y.has_value()) && "Multiple designators"); - if (!x.has_value()) { - x = std::move(y); - } - }}; - auto moveAppend{[](auto &accum, auto &&other) { - for (auto &&s : other) { - accum.push_back(std::move(s)); - } - }}; - (setValue(v.first, std::move(results).first), ...); - (moveAppend(v.second, std::move(results).second), ...); - return v; - } - -private: - template // - static bool IsZero(const T &x) { - return false; - } - template // - static bool IsZero(const evaluate::Expr &x) { - return common::visit([](auto &&s) { return IsZero(s); }, x.u); - } - template // - static bool IsZero(const evaluate::Constant &x) { - if (auto &&maybeScalar{x.GetScalarValue()}) { - return maybeScalar->IsZero(); - } else { - return false; - } - } - - template // - struct is_convert { - static constexpr bool value{false}; - }; - template // - struct is_convert> { - static constexpr bool value{true}; - }; - template // - struct is_convert> { - // Conversion from complex to real. - static constexpr bool value{true}; - }; - template // - static constexpr bool is_convert_v = is_convert::value; - - template // - struct is_complex_constructor { - static constexpr bool value{false}; - }; - template // - struct is_complex_constructor> { - static constexpr bool value{true}; - }; - template // - static constexpr bool is_complex_constructor_v = - is_complex_constructor::value; -}; -} // namespace operation - -MaybeExpr GetConvertInput(const SomeExpr &x) { - // This returns SomeExpr(x) when x is a designator/functionref/constant. - return operation::ConvertCollector{}(x).first; -} - -bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) { - // Check if expr is same as x, or a sequence of Convert operations on x. - if (expr == x) { - return true; - } else if (auto maybe{GetConvertInput(expr)}) { - return *maybe == x; - } else { - return false; - } -} } // namespace Fortran::semantics \ No newline at end of file