Skip to content

Commit c8cb30f

Browse files
committed
Swift: refactor StmtVisitor to use translations
Also make `visit` in `SwiftDispatcher` work on `const` pointers. Also, fixed a bug where the guard of a `CaseLabelItem` was not being extracted, hence the test updates.
1 parent faf1029 commit c8cb30f

File tree

10 files changed

+202
-209
lines changed

10 files changed

+202
-209
lines changed

swift/extractor/infra/SwiftDispatcher.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ class SwiftDispatcher {
211211
template <typename Iterable>
212212
auto fetchRepeatedLabels(Iterable&& arg) {
213213
std::vector<decltype(fetchLabel(*arg.begin()))> ret;
214-
ret.reserve(arg.size());
214+
if constexpr (HasSize<Iterable>::value) {
215+
ret.reserve(arg.size());
216+
}
215217
for (auto&& e : arg) {
216218
ret.push_back(fetchLabel(e));
217219
}
@@ -262,6 +264,12 @@ class SwiftDispatcher {
262264
}
263265

264266
private:
267+
template <typename T, typename = void>
268+
struct HasSize : std::false_type {};
269+
270+
template <typename T>
271+
struct HasSize<T, decltype(std::declval<T>().size(), void())> : std::true_type {};
272+
265273
void attachLocation(swift::SourceLoc start,
266274
swift::SourceLoc end,
267275
TrapLabel<LocatableTag> locatableLabel) {
@@ -323,14 +331,14 @@ class SwiftDispatcher {
323331
// TODO: for const correctness these should consistently be `const` (and maybe const references
324332
// as we don't expect `nullptr` here. However `swift::ASTVisitor` and `swift::TypeVisitor` do not
325333
// accept const pointers
326-
virtual void visit(swift::Decl* decl) = 0;
327-
virtual void visit(swift::Stmt* stmt) = 0;
334+
virtual void visit(const swift::Decl* decl) = 0;
335+
virtual void visit(const swift::Stmt* stmt) = 0;
328336
virtual void visit(const swift::StmtCondition* cond) = 0;
329337
virtual void visit(const swift::StmtConditionElement* cond) = 0;
330-
virtual void visit(swift::CaseLabelItem* item) = 0;
331-
virtual void visit(swift::Expr* expr) = 0;
338+
virtual void visit(const swift::CaseLabelItem* item) = 0;
339+
virtual void visit(const swift::Expr* expr) = 0;
332340
virtual void visit(const swift::Pattern* pattern) = 0;
333-
virtual void visit(swift::TypeRepr* typeRepr, swift::Type type) = 0;
341+
virtual void visit(const swift::TypeRepr* typeRepr, swift::Type type) = 0;
334342
virtual void visit(swift::TypeBase* type) = 0;
335343

336344
void visit(const std::filesystem::path& file) {

swift/extractor/visitors/ExprVisitor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ void ExprVisitor::emitAccessorSemantics(T* ast, Label label) {
2323
}
2424
}
2525

26-
void ExprVisitor::visit(swift::Expr* expr) {
27-
swift::ExprVisitor<ExprVisitor, void>::visit(expr);
26+
void ExprVisitor::visit(const swift::Expr* expr) {
27+
AstVisitorBase<ExprVisitor>::visit(expr);
2828
auto label = dispatcher_.fetchLabel(expr);
2929
if (auto type = expr->getType()) {
3030
dispatcher_.emit(ExprTypesTrap{label, dispatcher_.fetchLabel(type)});

swift/extractor/visitors/ExprVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class ExprVisitor : public AstVisitorBase<ExprVisitor> {
99
public:
1010
using AstVisitorBase<ExprVisitor>::AstVisitorBase;
1111

12-
void visit(swift::Expr* expr);
12+
void visit(const swift::Expr* expr);
1313
void visitIntegerLiteralExpr(swift::IntegerLiteralExpr* expr);
1414
void visitFloatLiteralExpr(swift::FloatLiteralExpr* expr);
1515
void visitBooleanLiteralExpr(swift::BooleanLiteralExpr* expr);

swift/extractor/visitors/PatternVisitor.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,6 @@ namespace codeql {
88
class PatternVisitor : public AstVisitorBase<PatternVisitor> {
99
public:
1010
using AstVisitorBase<PatternVisitor>::AstVisitorBase;
11-
using AstVisitorBase<PatternVisitor>::visit;
12-
13-
// TODO
14-
// swift does not provide const visitors, for the moment we const_cast and promise not to
15-
// change the entities. When all visitors have been turned to translators, we can ditch
16-
// swift::ASTVisitor and roll out our own const-correct TranslatorBase class
17-
void visit(const swift::Pattern* pattern) { visit(const_cast<swift::Pattern*>(pattern)); }
1811

1912
codeql::NamedPattern translateNamedPattern(const swift::NamedPattern& pattern);
2013
codeql::TypedPattern translateTypedPattern(const swift::TypedPattern& pattern);

swift/extractor/visitors/StmtVisitor.cpp

Lines changed: 107 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22

33
namespace codeql {
44

5-
void StmtVisitor::visitLabeledStmt(swift::LabeledStmt* stmt) {
6-
auto label = dispatcher_.assignNewLabel(stmt);
7-
emitLabeledStmt(stmt, label);
8-
}
9-
105
codeql::StmtCondition StmtVisitor::translateStmtCondition(const swift::StmtCondition& cond) {
116
auto entry = dispatcher_.createEntry(cond);
127
entry.elements = dispatcher_.fetchRepeatedLabels(cond);
@@ -25,197 +20,157 @@ codeql::ConditionElement StmtVisitor::translateStmtConditionElement(
2520
return entry;
2621
}
2722

28-
void StmtVisitor::visitLabeledConditionalStmt(swift::LabeledConditionalStmt* stmt) {
29-
auto label = dispatcher_.assignNewLabel(stmt);
30-
emitLabeledStmt(stmt, label);
31-
emitLabeledConditionalStmt(stmt, label);
32-
}
33-
34-
void StmtVisitor::visitCaseLabelItem(swift::CaseLabelItem* labelItem) {
35-
auto label = dispatcher_.assignNewLabel(labelItem);
36-
assert(labelItem->getPattern() && "CaseLabelItem has Pattern");
37-
dispatcher_.emit(CaseLabelItemsTrap{label, dispatcher_.fetchLabel(labelItem->getPattern())});
23+
codeql::CaseLabelItem StmtVisitor::translateCaseLabelItem(const swift::CaseLabelItem& labelItem) {
24+
auto entry = dispatcher_.createEntry(labelItem);
25+
entry.pattern = dispatcher_.fetchLabel(labelItem.getPattern());
26+
entry.guard = dispatcher_.fetchOptionalLabel(labelItem.getGuardExpr());
27+
return entry;
3828
}
3929

40-
void StmtVisitor::visitBraceStmt(swift::BraceStmt* stmt) {
41-
auto label = dispatcher_.assignNewLabel(stmt);
42-
dispatcher_.emit(BraceStmtsTrap{label});
43-
auto i = 0u;
44-
for (auto& e : stmt->getElements()) {
45-
dispatcher_.emit(BraceStmtElementsTrap{label, i++, dispatcher_.fetchLabel(e)});
46-
}
30+
codeql::BraceStmt StmtVisitor::translateBraceStmt(const swift::BraceStmt& stmt) {
31+
auto entry = dispatcher_.createEntry(stmt);
32+
entry.elements = dispatcher_.fetchRepeatedLabels(stmt.getElements());
33+
return entry;
4734
}
4835

49-
void StmtVisitor::visitReturnStmt(swift::ReturnStmt* stmt) {
50-
auto label = dispatcher_.assignNewLabel(stmt);
51-
dispatcher_.emit(ReturnStmtsTrap{label});
52-
if (stmt->hasResult()) {
53-
auto resultLabel = dispatcher_.fetchLabel(stmt->getResult());
54-
dispatcher_.emit(ReturnStmtResultsTrap{label, resultLabel});
36+
codeql::ReturnStmt StmtVisitor::translateReturnStmt(const swift::ReturnStmt& stmt) {
37+
auto entry = dispatcher_.createEntry(stmt);
38+
if (stmt.hasResult()) {
39+
entry.result = dispatcher_.fetchLabel(stmt.getResult());
5540
}
41+
return entry;
5642
}
5743

58-
void StmtVisitor::visitForEachStmt(swift::ForEachStmt* stmt) {
59-
auto label = dispatcher_.assignNewLabel(stmt);
60-
assert(stmt->getBody() && "ForEachStmt has getBody()");
61-
assert(stmt->getParsedSequence() && "ForEachStmt has getParsedSequence()");
62-
assert(stmt->getPattern() && "ForEachStmt has getPattern()");
63-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
64-
auto sequenceLabel = dispatcher_.fetchLabel(stmt->getParsedSequence());
65-
auto patternLabel = dispatcher_.fetchLabel(stmt->getPattern());
66-
emitLabeledStmt(stmt, label);
67-
dispatcher_.emit(ForEachStmtsTrap{label, patternLabel, sequenceLabel, bodyLabel});
68-
if (auto where = stmt->getWhere()) {
69-
auto whereLabel = dispatcher_.fetchLabel(where);
70-
dispatcher_.emit(ForEachStmtWheresTrap{label, whereLabel});
71-
}
44+
codeql::ForEachStmt StmtVisitor::translateForEachStmt(const swift::ForEachStmt& stmt) {
45+
auto entry = dispatcher_.createEntry(stmt);
46+
fillLabeledStmt(stmt, entry);
47+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
48+
entry.sequence = dispatcher_.fetchLabel(stmt.getParsedSequence());
49+
entry.pattern = dispatcher_.fetchLabel(stmt.getPattern());
50+
entry.where = dispatcher_.fetchOptionalLabel(stmt.getWhere());
51+
return entry;
7252
}
7353

74-
void StmtVisitor::visitIfStmt(swift::IfStmt* stmt) {
75-
auto label = dispatcher_.assignNewLabel(stmt);
76-
emitLabeledStmt(stmt, label);
77-
emitLabeledConditionalStmt(stmt, label);
78-
auto thenLabel = dispatcher_.fetchLabel(stmt->getThenStmt());
79-
dispatcher_.emit(IfStmtsTrap{label, thenLabel});
80-
if (auto* elseStmt = stmt->getElseStmt()) {
81-
auto elseLabel = dispatcher_.fetchLabel(elseStmt);
82-
dispatcher_.emit(IfStmtElsesTrap{label, elseLabel});
83-
}
54+
codeql::IfStmt StmtVisitor::translateIfStmt(const swift::IfStmt& stmt) {
55+
auto entry = dispatcher_.createEntry(stmt);
56+
fillLabeledConditionalStmt(stmt, entry);
57+
entry.then = dispatcher_.fetchLabel(stmt.getThenStmt());
58+
entry.else_ = dispatcher_.fetchOptionalLabel(stmt.getElseStmt());
59+
return entry;
8460
}
8561

86-
void StmtVisitor::visitBreakStmt(swift::BreakStmt* stmt) {
87-
auto label = dispatcher_.assignNewLabel(stmt);
88-
dispatcher_.emit(BreakStmtsTrap{label});
89-
if (auto* target = stmt->getTarget()) {
90-
auto targetlabel = dispatcher_.fetchLabel(target);
91-
dispatcher_.emit(BreakStmtTargetsTrap{label, targetlabel});
92-
}
93-
auto targetName = stmt->getTargetName();
94-
if (!targetName.empty()) {
95-
dispatcher_.emit(BreakStmtTargetNamesTrap{label, targetName.str().str()});
62+
codeql::BreakStmt StmtVisitor::translateBreakStmt(const swift::BreakStmt& stmt) {
63+
auto entry = dispatcher_.createEntry(stmt);
64+
entry.target = dispatcher_.fetchOptionalLabel(stmt.getTarget());
65+
if (auto targetName = stmt.getTargetName(); !targetName.empty()) {
66+
entry.target_name = targetName.str().str();
9667
}
68+
return entry;
9769
}
9870

99-
void StmtVisitor::visitContinueStmt(swift::ContinueStmt* stmt) {
100-
auto label = dispatcher_.assignNewLabel(stmt);
101-
dispatcher_.emit(ContinueStmtsTrap{label});
102-
if (auto* target = stmt->getTarget()) {
103-
auto targetlabel = dispatcher_.fetchLabel(target);
104-
dispatcher_.emit(ContinueStmtTargetsTrap{label, targetlabel});
105-
}
106-
auto targetName = stmt->getTargetName();
107-
if (!targetName.empty()) {
108-
dispatcher_.emit(ContinueStmtTargetNamesTrap{label, targetName.str().str()});
71+
codeql::ContinueStmt StmtVisitor::translateContinueStmt(const swift::ContinueStmt& stmt) {
72+
auto entry = dispatcher_.createEntry(stmt);
73+
entry.target = dispatcher_.fetchOptionalLabel(stmt.getTarget());
74+
if (auto targetName = stmt.getTargetName(); !targetName.empty()) {
75+
entry.target_name = targetName.str().str();
10976
}
77+
return entry;
11078
}
11179

112-
void StmtVisitor::visitWhileStmt(swift::WhileStmt* stmt) {
113-
auto label = dispatcher_.assignNewLabel(stmt);
114-
emitLabeledStmt(stmt, label);
115-
emitLabeledConditionalStmt(stmt, label);
116-
dispatcher_.emit(WhileStmtsTrap{label, dispatcher_.fetchLabel(stmt->getBody())});
80+
codeql::WhileStmt StmtVisitor::translateWhileStmt(const swift::WhileStmt& stmt) {
81+
auto entry = dispatcher_.createEntry(stmt);
82+
fillLabeledConditionalStmt(stmt, entry);
83+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
84+
return entry;
11785
}
11886

119-
void StmtVisitor::visitRepeatWhileStmt(swift::RepeatWhileStmt* stmt) {
120-
auto label = dispatcher_.assignNewLabel(stmt);
121-
emitLabeledStmt(stmt, label);
122-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
123-
auto condLabel = dispatcher_.fetchLabel(stmt->getCond());
124-
dispatcher_.emit(RepeatWhileStmtsTrap{label, condLabel, bodyLabel});
87+
codeql::RepeatWhileStmt StmtVisitor::translateRepeatWhileStmt(const swift::RepeatWhileStmt& stmt) {
88+
auto entry = dispatcher_.createEntry(stmt);
89+
fillLabeledStmt(stmt, entry);
90+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
91+
entry.condition = dispatcher_.fetchLabel(stmt.getCond());
92+
return entry;
12593
}
12694

127-
void StmtVisitor::visitDoCatchStmt(swift::DoCatchStmt* stmt) {
128-
auto label = dispatcher_.assignNewLabel(stmt);
129-
emitLabeledStmt(stmt, label);
130-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
131-
dispatcher_.emit(DoCatchStmtsTrap{label, bodyLabel});
132-
auto i = 0u;
133-
for (auto* stmtCatch : stmt->getCatches()) {
134-
dispatcher_.emit(DoCatchStmtCatchesTrap{label, i++, dispatcher_.fetchLabel(stmtCatch)});
135-
}
95+
codeql::DoCatchStmt StmtVisitor::translateDoCatchStmt(const swift::DoCatchStmt& stmt) {
96+
auto entry = dispatcher_.createEntry(stmt);
97+
fillLabeledStmt(stmt, entry);
98+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
99+
entry.catches = dispatcher_.fetchRepeatedLabels(stmt.getCatches());
100+
return entry;
136101
}
137102

138-
void StmtVisitor::visitCaseStmt(swift::CaseStmt* stmt) {
139-
auto label = dispatcher_.assignNewLabel(stmt);
140-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
141-
dispatcher_.emit(CaseStmtsTrap{label, bodyLabel});
142-
auto i = 0u;
143-
for (auto& item : stmt->getMutableCaseLabelItems()) {
144-
dispatcher_.emit(CaseStmtLabelsTrap{label, i++, dispatcher_.fetchLabel(&item)});
145-
}
146-
if (stmt->hasCaseBodyVariables()) {
147-
auto i = 0u;
148-
for (auto* var : stmt->getCaseBodyVariables()) {
149-
dispatcher_.emit(CaseStmtVariablesTrap{label, i++, dispatcher_.fetchLabel(var)});
103+
codeql::CaseStmt StmtVisitor::translateCaseStmt(const swift::CaseStmt& stmt) {
104+
auto entry = dispatcher_.createEntry(stmt);
105+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
106+
entry.labels = dispatcher_.fetchRepeatedLabels(stmt.getCaseLabelItems());
107+
if (stmt.hasCaseBodyVariables()) {
108+
for (auto var : stmt.getCaseBodyVariables()) {
109+
entry.variables.push_back(dispatcher_.fetchLabel(var));
150110
}
151111
}
112+
return entry;
152113
}
153114

154-
void StmtVisitor::visitGuardStmt(swift::GuardStmt* stmt) {
155-
auto label = dispatcher_.assignNewLabel(stmt);
156-
emitLabeledStmt(stmt, label);
157-
emitLabeledConditionalStmt(stmt, label);
158-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
159-
dispatcher_.emit(GuardStmtsTrap{label, bodyLabel});
115+
codeql::GuardStmt StmtVisitor::translateGuardStmt(const swift::GuardStmt& stmt) {
116+
auto entry = dispatcher_.createEntry(stmt);
117+
fillLabeledConditionalStmt(stmt, entry);
118+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
119+
return entry;
160120
}
161121

162-
void StmtVisitor::visitThrowStmt(swift::ThrowStmt* stmt) {
163-
auto label = dispatcher_.assignNewLabel(stmt);
164-
auto subExprLabel = dispatcher_.fetchLabel(stmt->getSubExpr());
165-
dispatcher_.emit(ThrowStmtsTrap{label, subExprLabel});
122+
codeql::ThrowStmt StmtVisitor::translateThrowStmt(const swift::ThrowStmt& stmt) {
123+
auto entry = dispatcher_.createEntry(stmt);
124+
entry.sub_expr = dispatcher_.fetchLabel(stmt.getSubExpr());
125+
return entry;
166126
}
167127

168-
void StmtVisitor::visitDeferStmt(swift::DeferStmt* stmt) {
169-
auto label = dispatcher_.assignNewLabel(stmt);
170-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBodyAsWritten());
171-
dispatcher_.emit(DeferStmtsTrap{label, bodyLabel});
128+
codeql::DeferStmt StmtVisitor::translateDeferStmt(const swift::DeferStmt& stmt) {
129+
auto entry = dispatcher_.createEntry(stmt);
130+
entry.body = dispatcher_.fetchLabel(stmt.getBodyAsWritten());
131+
return entry;
172132
}
173133

174-
void StmtVisitor::visitDoStmt(swift::DoStmt* stmt) {
175-
auto label = dispatcher_.assignNewLabel(stmt);
176-
emitLabeledStmt(stmt, label);
177-
auto bodyLabel = dispatcher_.fetchLabel(stmt->getBody());
178-
dispatcher_.emit(DoStmtsTrap{label, bodyLabel});
134+
codeql::DoStmt StmtVisitor::translateDoStmt(const swift::DoStmt& stmt) {
135+
auto entry = dispatcher_.createEntry(stmt);
136+
fillLabeledStmt(stmt, entry);
137+
entry.body = dispatcher_.fetchLabel(stmt.getBody());
138+
return entry;
179139
}
180140

181-
void StmtVisitor::visitSwitchStmt(swift::SwitchStmt* stmt) {
182-
auto label = dispatcher_.assignNewLabel(stmt);
183-
emitLabeledStmt(stmt, label);
184-
auto subjectLabel = dispatcher_.fetchLabel(stmt->getSubjectExpr());
185-
dispatcher_.emit(SwitchStmtsTrap{label, subjectLabel});
186-
auto i = 0u;
187-
for (auto* c : stmt->getCases()) {
188-
dispatcher_.emit(SwitchStmtCasesTrap{label, i++, dispatcher_.fetchLabel(c)});
189-
}
141+
codeql::SwitchStmt StmtVisitor::translateSwitchStmt(const swift::SwitchStmt& stmt) {
142+
auto entry = dispatcher_.createEntry(stmt);
143+
fillLabeledStmt(stmt, entry);
144+
entry.expr = dispatcher_.fetchLabel(stmt.getSubjectExpr());
145+
entry.cases = dispatcher_.fetchRepeatedLabels(stmt.getCases());
146+
return entry;
190147
}
191148

192-
void StmtVisitor::visitFallthroughStmt(swift::FallthroughStmt* stmt) {
193-
auto label = dispatcher_.assignNewLabel(stmt);
194-
auto sourceLabel = dispatcher_.fetchLabel(stmt->getFallthroughSource());
195-
auto destLabel = dispatcher_.fetchLabel(stmt->getFallthroughDest());
196-
dispatcher_.emit(FallthroughStmtsTrap{label, sourceLabel, destLabel});
149+
codeql::FallthroughStmt StmtVisitor::translateFallthroughStmt(const swift::FallthroughStmt& stmt) {
150+
auto entry = dispatcher_.createEntry(stmt);
151+
entry.fallthrough_source = dispatcher_.fetchLabel(stmt.getFallthroughSource());
152+
entry.fallthrough_dest = dispatcher_.fetchLabel(stmt.getFallthroughDest());
153+
return entry;
197154
}
198155

199-
void StmtVisitor::visitYieldStmt(swift::YieldStmt* stmt) {
200-
auto label = dispatcher_.assignNewLabel(stmt);
201-
dispatcher_.emit(YieldStmtsTrap{label});
202-
auto i = 0u;
203-
for (auto* expr : stmt->getYields()) {
204-
auto exprLabel = dispatcher_.fetchLabel(expr);
205-
dispatcher_.emit(YieldStmtResultsTrap{label, i++, exprLabel});
206-
}
156+
codeql::YieldStmt StmtVisitor::translateYieldStmt(const swift::YieldStmt& stmt) {
157+
auto entry = dispatcher_.createEntry(stmt);
158+
entry.results = dispatcher_.fetchRepeatedLabels(stmt.getYields());
159+
return entry;
207160
}
208161

209-
void StmtVisitor::emitLabeledStmt(const swift::LabeledStmt* stmt, TrapLabel<LabeledStmtTag> label) {
210-
if (stmt->getLabelInfo()) {
211-
dispatcher_.emit(LabeledStmtLabelsTrap{label, stmt->getLabelInfo().Name.str().str()});
162+
void StmtVisitor::fillLabeledStmt(const swift::LabeledStmt& stmt, codeql::LabeledStmt& entry) {
163+
if (auto info = stmt.getLabelInfo()) {
164+
entry.label = info.Name.str().str();
212165
}
213166
}
214167

215-
void StmtVisitor::emitLabeledConditionalStmt(swift::LabeledConditionalStmt* stmt,
216-
TrapLabel<LabeledConditionalStmtTag> label) {
217-
auto condLabel = dispatcher_.fetchLabel(stmt->getCondPointer());
218-
dispatcher_.emit(LabeledConditionalStmtsTrap{label, condLabel});
168+
void StmtVisitor::fillLabeledConditionalStmt(const swift::LabeledConditionalStmt& stmt,
169+
codeql::LabeledConditionalStmt& entry) {
170+
// getCondPointer not provided for const stmt by swift...
171+
entry.condition =
172+
dispatcher_.fetchLabel(const_cast<swift::LabeledConditionalStmt&>(stmt).getCondPointer());
173+
fillLabeledStmt(stmt, entry);
219174
}
220175

221176
} // namespace codeql

0 commit comments

Comments
 (0)