Skip to content

Commit e3e6729

Browse files
authored
Merge pull request #11481 from ethereum/unify-resolve
Unify function call resolve function used in Analysis & Yul CodeGen
2 parents ad3bc71 + 6a0313c commit e3e6729

File tree

5 files changed

+60
-101
lines changed

5 files changed

+60
-101
lines changed

libsolidity/analysis/ControlFlowRevertPruner.cpp

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -54,58 +54,6 @@ void ControlFlowRevertPruner::run()
5454
modifyFunctionFlows();
5555
}
5656

57-
FunctionDefinition const* ControlFlowRevertPruner::resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _contract)
58-
{
59-
auto result = m_resolveCache.find({&_functionCall, _contract});
60-
if (result != m_resolveCache.end())
61-
return result->second;
62-
63-
auto const& functionType = dynamic_cast<FunctionType const&>(
64-
*_functionCall.expression().annotation().type
65-
);
66-
67-
if (!functionType.hasDeclaration())
68-
return nullptr;
69-
70-
auto const& unresolvedFunctionDefinition =
71-
dynamic_cast<FunctionDefinition const&>(functionType.declaration());
72-
73-
FunctionDefinition const* returnFunctionDef = &unresolvedFunctionDefinition;
74-
75-
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
76-
{
77-
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
78-
{
79-
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
80-
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
81-
{
82-
solAssert(contractType->isSuper(), "");
83-
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_contract);
84-
85-
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(
86-
*_contract,
87-
superContract
88-
);
89-
}
90-
}
91-
else
92-
{
93-
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
94-
returnFunctionDef = &unresolvedFunctionDefinition;
95-
}
96-
}
97-
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
98-
{
99-
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
100-
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(*_contract);
101-
}
102-
103-
if (returnFunctionDef && !returnFunctionDef->isImplemented())
104-
returnFunctionDef = nullptr;
105-
106-
return m_resolveCache[{&_functionCall, _contract}] = returnFunctionDef;
107-
}
108-
10957
void ControlFlowRevertPruner::findRevertStates()
11058
{
11159
std::set<CFG::FunctionContractTuple> pendingFunctions = keys(m_functions);
@@ -130,9 +78,9 @@ void ControlFlowRevertPruner::findRevertStates()
13078

13179
for (auto const* functionCall: _node->functionCalls)
13280
{
133-
auto const* resolvedFunction = resolveCall(*functionCall, item.contract);
81+
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract);
13482

135-
if (resolvedFunction == nullptr)
83+
if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
13684
continue;
13785

13886
switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction}))
@@ -180,9 +128,9 @@ void ControlFlowRevertPruner::modifyFunctionFlows()
180128
[&](CFGNode* _node, auto&& _addChild) {
181129
for (auto const* functionCall: _node->functionCalls)
182130
{
183-
auto const* resolvedFunction = resolveCall(*functionCall, item.first.contract);
131+
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract);
184132

185-
if (resolvedFunction == nullptr)
133+
if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
186134
continue;
187135

188136
switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
@@ -223,7 +171,11 @@ void ControlFlowRevertPruner::collectCalls(FunctionDefinition const& _function,
223171
solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
224172
[&](CFGNode* _node, auto&& _addChild) {
225173
for (auto const* functionCall: _node->functionCalls)
226-
m_calledBy[resolveCall(*functionCall, _mostDerivedContract)].insert(pair);
174+
{
175+
auto const* funcDef = ASTNode::resolveFunctionCall(*functionCall, _mostDerivedContract);
176+
if (funcDef && funcDef->isImplemented())
177+
m_calledBy[funcDef].insert(pair);
178+
}
227179

228180
for (CFGNode* exit: _node->exits)
229181
_addChild(exit);

libsolidity/analysis/ControlFlowRevertPruner.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,6 @@ class ControlFlowRevertPruner
4545
Unknown,
4646
};
4747

48-
/// Simple attempt at resolving a function call
49-
/// Does not aim to be able to resolve all calls, only used for variable
50-
/// assignment tracking and revert behavior.
51-
/// @param _functionCall the function call to analyse
52-
/// @param _mostDerivedContract most derived contract
53-
/// @returns function definition to which the call resolved or nullptr if no
54-
/// definition was found.
55-
FunctionDefinition const* resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
56-
5748
/// Identify revert states of all function flows
5849
void findRevertStates();
5950

libsolidity/ast/AST.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,50 @@ Declaration const* ASTNode::referencedDeclaration(Expression const& _expression)
5757
return nullptr;
5858
}
5959

60+
FunctionDefinition const* ASTNode::resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract)
61+
{
62+
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
63+
ASTNode::referencedDeclaration(_functionCall.expression())
64+
);
65+
66+
if (!functionDef)
67+
return nullptr;
68+
69+
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
70+
{
71+
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
72+
{
73+
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
74+
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
75+
{
76+
solAssert(_mostDerivedContract, "");
77+
solAssert(contractType->isSuper(), "");
78+
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_mostDerivedContract);
79+
80+
return &functionDef->resolveVirtual(
81+
*_mostDerivedContract,
82+
superContract
83+
);
84+
}
85+
}
86+
else
87+
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
88+
}
89+
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
90+
{
91+
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
92+
if (functionDef->virtualSemantics())
93+
{
94+
solAssert(_mostDerivedContract, "");
95+
return &functionDef->resolveVirtual(*_mostDerivedContract);
96+
}
97+
}
98+
else
99+
solAssert(false, "");
100+
101+
return functionDef;
102+
}
103+
60104
ASTAnnotation& ASTNode::annotation() const
61105
{
62106
if (!m_annotation)

libsolidity/ast/AST.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class ASTNode
104104
/// Extracts the referenced declaration from all nodes whose annotations support
105105
/// `referencedDeclaration`.
106106
static Declaration const* referencedDeclaration(Expression const& _expression);
107+
/// Performs potential super or virtual lookup for a function call based on the most derived contract.
108+
static FunctionDefinition const* resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
107109

108110
/// Returns the source code location of this node.
109111
SourceLocation const& location() const { return m_location; }

libsolidity/codegen/ir/IRGeneratorForStatements.cpp

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -884,48 +884,14 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
884884
return;
885885
}
886886

887-
auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression());
888-
889887
switch (functionType->kind())
890888
{
891889
case FunctionType::Kind::Declaration:
892890
solAssert(false, "Attempted to generate code for calling a function definition.");
893891
break;
894892
case FunctionType::Kind::Internal:
895893
{
896-
auto identifier = dynamic_cast<Identifier const*>(&_functionCall.expression());
897-
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
898-
ASTNode::referencedDeclaration(_functionCall.expression())
899-
);
900-
901-
if (functionDef)
902-
{
903-
solAssert(memberAccess || identifier, "");
904-
solAssert(functionType->declaration() == *functionDef, "");
905-
906-
if (identifier)
907-
{
908-
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
909-
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract());
910-
}
911-
else if (auto typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
912-
if (
913-
auto contractType = dynamic_cast<ContractType const*>(typeType->actualType());
914-
contractType->isSuper()
915-
)
916-
{
917-
ContractDefinition const* super = contractType->contractDefinition().superContract(m_context.mostDerivedContract());
918-
solAssert(super, "Super contract not available.");
919-
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Super, "");
920-
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract(), super);
921-
}
922-
923-
solAssert(functionDef && functionDef->isImplemented(), "");
924-
solAssert(
925-
functionDef->parameters().size() == arguments.size() + (functionType->bound() ? 1 : 0),
926-
""
927-
);
928-
}
894+
FunctionDefinition const* functionDef = ASTNode::resolveFunctionCall(_functionCall, &m_context.mostDerivedContract());
929895

930896
solAssert(!functionType->takesArbitraryParameters(), "");
931897

@@ -937,11 +903,15 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
937903
args += convert(*arguments[i], *parameterTypes[i]).stackSlots();
938904

939905
if (functionDef)
906+
{
907+
solAssert(functionDef->isImplemented(), "");
908+
940909
define(_functionCall) <<
941910
m_context.enqueueFunctionForCodeGeneration(*functionDef) <<
942911
"(" <<
943912
joinHumanReadable(args) <<
944913
")\n";
914+
}
945915
else
946916
{
947917
YulArity arity = YulArity::fromType(*functionType);

0 commit comments

Comments
 (0)