Skip to content

Commit 2bb5603

Browse files
authored
Merge pull request github#17949 from paldepind/rust-async-blocks
Rust: Handle async blocks in CFG and SSA
2 parents 67684d1 + 274d942 commit 2bb5603

File tree

18 files changed

+825
-534
lines changed

18 files changed

+825
-534
lines changed

rust/ql/lib/codeql/rust/controlflow/BasicBlocks.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ final class BasicBlock = BasicBlockImpl;
1212
* without branches or joins.
1313
*/
1414
private class BasicBlockImpl extends TBasicBlockStart {
15-
/** Gets the scope of this basic block. */
15+
/** Gets the CFG scope of this basic block. */
1616
CfgScope getScope() { result = this.getAPredecessor().getScope() }
1717

1818
/** Gets an immediate successor of this basic block, if any. */

rust/ql/lib/codeql/rust/controlflow/internal/CfgConsistency.qll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ query predicate nonPostOrderExpr(Expr e, string cls) {
3232
*/
3333
query predicate scopeNoFirst(CfgScope scope) {
3434
Consistency::scopeNoFirst(scope) and
35-
not scope = any(Function f | not exists(f.getBody())) and
36-
not scope = any(ClosureExpr c | not exists(c.getBody()))
35+
not scope =
36+
[
37+
any(AstNode f | not f.(Function).hasBody()),
38+
any(ClosureExpr c | not c.hasBody()),
39+
any(AsyncBlockExpr b | not b.hasStmtList())
40+
]
3741
}
3842

3943
/** Holds if `be` is the `else` branch of a `let` statement that results in a panic. */

rust/ql/lib/codeql/rust/controlflow/internal/ControlFlowGraphImpl.qll

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ private module CfgInput implements InputSig<Location> {
2323
class CfgScope = Scope::CfgScope;
2424

2525
CfgScope getCfgScope(AstNode n) {
26-
result = n.getEnclosingCallable() and
26+
result = n.getEnclosingCfgScope() and
2727
Stages::CfgStage::ref()
2828
}
2929

@@ -44,12 +44,10 @@ private module CfgInput implements InputSig<Location> {
4444
predicate successorTypeIsCondition(SuccessorType t) { t instanceof Cfg::BooleanSuccessor }
4545

4646
/** Holds if `first` is first executed when entering `scope`. */
47-
predicate scopeFirst(CfgScope scope, AstNode first) {
48-
first(scope.(CfgScopeTree).getFirstChildNode(), first)
49-
}
47+
predicate scopeFirst(CfgScope scope, AstNode first) { scope.scopeFirst(first) }
5048

5149
/** Holds if `scope` is exited when `last` finishes with completion `c`. */
52-
predicate scopeLast(CfgScope scope, AstNode last, Completion c) { last(scope.getBody(), last, c) }
50+
predicate scopeLast(CfgScope scope, AstNode last, Completion c) { scope.scopeLast(last, c) }
5351
}
5452

5553
private module CfgSplittingInput implements SplittingInputSig<Location, CfgInput> {
@@ -71,14 +69,7 @@ private module CfgImpl =
7169

7270
import CfgImpl
7371

74-
class CfgScopeTree extends StandardTree, Scope::CfgScope {
75-
override predicate first(AstNode first) { first = this }
76-
77-
override predicate last(AstNode last, Completion c) {
78-
last = this and
79-
completionIsValidFor(c, this)
80-
}
81-
72+
class CallableScopeTree extends StandardTree, PreOrderTree, PostOrderTree, Scope::CallableScope {
8273
override predicate propagatesAbnormal(AstNode child) { none() }
8374

8475
override AstNode getChildNode(int i) {
@@ -280,13 +271,23 @@ module ExprTrees {
280271
}
281272
}
282273

274+
private AstNode getBlockChildNode(BlockExpr b, int i) {
275+
result = b.getStmtList().getStatement(i)
276+
or
277+
i = b.getStmtList().getNumberOfStatements() and
278+
result = b.getStmtList().getTailExpr()
279+
}
280+
281+
class AsyncBlockExprTree extends StandardTree, PreOrderTree, PostOrderTree, AsyncBlockExpr {
282+
override AstNode getChildNode(int i) { result = getBlockChildNode(this, i) }
283+
284+
override predicate propagatesAbnormal(AstNode child) { none() }
285+
}
286+
283287
class BlockExprTree extends StandardPostOrderTree, BlockExpr {
284-
override AstNode getChildNode(int i) {
285-
result = this.getStmtList().getStatement(i)
286-
or
287-
i = this.getStmtList().getNumberOfStatements() and
288-
result = this.getStmtList().getTailExpr()
289-
}
288+
BlockExprTree() { not this.isAsync() }
289+
290+
override AstNode getChildNode(int i) { result = getBlockChildNode(this, i) }
290291

291292
override predicate propagatesAbnormal(AstNode child) { child = this.getChildNode(_) }
292293
}

rust/ql/lib/codeql/rust/controlflow/internal/Scope.qll

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,31 @@ private import codeql.rust.elements.internal.generated.ParentChild
55

66
/**
77
* A control-flow graph (CFG) scope.
8-
*
9-
* A CFG scope is a callable with a body.
108
*/
11-
class CfgScope extends Callable {
12-
CfgScope() {
9+
abstract private class CfgScopeImpl extends AstNode {
10+
/** Holds if `first` is executed first when entering `scope`. */
11+
abstract predicate scopeFirst(AstNode first);
12+
13+
/** Holds if `scope` is exited when `last` finishes with completion `c`. */
14+
abstract predicate scopeLast(AstNode last, Completion c);
15+
}
16+
17+
final class CfgScope = CfgScopeImpl;
18+
19+
final class AsyncBlockScope extends CfgScopeImpl, AsyncBlockExpr instanceof ExprTrees::AsyncBlockExprTree
20+
{
21+
override predicate scopeFirst(AstNode first) { first(super.getFirstChildNode(), first) }
22+
23+
override predicate scopeLast(AstNode last, Completion c) {
24+
last(super.getLastChildElement(), last, c)
25+
}
26+
}
27+
28+
/**
29+
* A CFG scope for a callable (a function or a closure) with a body.
30+
*/
31+
final class CallableScope extends CfgScopeImpl, Callable {
32+
CallableScope() {
1333
// A function without a body corresponds to a trait method signature and
1434
// should not have a CFG scope.
1535
this.(Function).hasBody()
@@ -23,4 +43,11 @@ class CfgScope extends Callable {
2343
or
2444
result = this.(ClosureExpr).getBody()
2545
}
46+
47+
override predicate scopeFirst(AstNode first) {
48+
first(this.(CallableScopeTree).getFirstChildNode(), first)
49+
}
50+
51+
/** Holds if `scope` is exited when `last` finishes with completion `c`. */
52+
override predicate scopeLast(AstNode last, Completion c) { last(this.getBody(), last, c) }
2653
}

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ final class NormalCall extends DataFlowCall, TNormalCall {
6565
override CallCfgNode asCall() { result = c }
6666

6767
override DataFlowCallable getEnclosingCallable() {
68-
result = TCfgScope(c.getExpr().getEnclosingCallable())
68+
result = TCfgScope(c.getExpr().getEnclosingCfgScope())
6969
}
7070

7171
override string toString() { result = c.toString() }
@@ -136,7 +136,7 @@ module Node {
136136

137137
ExprNode() { this = TExprNode(n) }
138138

139-
override CfgScope getCfgScope() { result = this.asExpr().getEnclosingCallable() }
139+
override CfgScope getCfgScope() { result = this.asExpr().getEnclosingCfgScope() }
140140

141141
override Location getLocation() { result = n.getExpr().getLocation() }
142142

@@ -156,7 +156,7 @@ module Node {
156156

157157
ParameterNode() { this = TParameterNode(parameter) }
158158

159-
override CfgScope getCfgScope() { result = parameter.getParam().getEnclosingCallable() }
159+
override CfgScope getCfgScope() { result = parameter.getParam().getEnclosingCfgScope() }
160160

161161
override Location getLocation() { result = parameter.getLocation() }
162162

rust/ql/lib/codeql/rust/dataflow/internal/SsaImpl.qll

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,19 +156,19 @@ private predicate variableReadActual(BasicBlock bb, int i, Variable v) {
156156
*/
157157
pragma[noinline]
158158
private predicate hasCapturedWrite(Variable v, Cfg::CfgScope scope) {
159-
any(VariableWriteAccess write | write.getVariable() = v and scope = write.getEnclosingCallable+())
159+
any(VariableWriteAccess write | write.getVariable() = v and scope = write.getEnclosingCfgScope+())
160160
.isCapture()
161161
}
162162

163163
/**
164164
* Holds if `v` is read inside basic block `bb` at index `i`, which is in the
165-
* immediate outer scope of `scope`.
165+
* immediate outer CFG scope of `scope`.
166166
*/
167167
pragma[noinline]
168168
private predicate variableReadActualInOuterScope(
169169
BasicBlock bb, int i, Variable v, Cfg::CfgScope scope
170170
) {
171-
variableReadActual(bb, i, v) and bb.getScope() = scope.getEnclosingCallable()
171+
variableReadActual(bb, i, v) and bb.getScope() = scope.getEnclosingCfgScope()
172172
}
173173

174174
pragma[noinline]
@@ -263,7 +263,7 @@ private predicate readsCapturedVariable(BasicBlock bb, Variable v) {
263263
*/
264264
pragma[noinline]
265265
private predicate hasCapturedRead(Variable v, Cfg::CfgScope scope) {
266-
any(VariableReadAccess read | read.getVariable() = v and scope = read.getEnclosingCallable+())
266+
any(VariableReadAccess read | read.getVariable() = v and scope = read.getEnclosingCfgScope+())
267267
.isCapture()
268268
}
269269

@@ -273,14 +273,18 @@ private predicate hasCapturedRead(Variable v, Cfg::CfgScope scope) {
273273
*/
274274
pragma[noinline]
275275
private predicate variableWriteInOuterScope(BasicBlock bb, int i, Variable v, Cfg::CfgScope scope) {
276-
SsaInput::variableWrite(bb, i, v, _) and scope.getEnclosingCallable() = bb.getScope()
276+
SsaInput::variableWrite(bb, i, v, _) and scope.getEnclosingCfgScope() = bb.getScope()
277277
}
278278

279+
/** Holds if evaluating `e` jumps to the evaluation of a different CFG scope. */
280+
private predicate isControlFlowJump(Expr e) { e instanceof CallExprBase or e instanceof AwaitExpr }
281+
279282
/**
280283
* Holds if the call `call` at index `i` in basic block `bb` may reach
281284
* a callable that reads captured variable `v`.
282285
*/
283-
private predicate capturedCallRead(CallExprBase call, BasicBlock bb, int i, Variable v) {
286+
private predicate capturedCallRead(Expr call, BasicBlock bb, int i, Variable v) {
287+
isControlFlowJump(call) and
284288
exists(Cfg::CfgScope scope |
285289
hasCapturedRead(v, scope) and
286290
(
@@ -295,7 +299,8 @@ private predicate capturedCallRead(CallExprBase call, BasicBlock bb, int i, Vari
295299
* Holds if the call `call` at index `i` in basic block `bb` may reach a callable
296300
* that writes captured variable `v`.
297301
*/
298-
predicate capturedCallWrite(CallExprBase call, BasicBlock bb, int i, Variable v) {
302+
predicate capturedCallWrite(Expr call, BasicBlock bb, int i, Variable v) {
303+
isControlFlowJump(call) and
299304
call = bb.getNode(i).getAstNode() and
300305
exists(Cfg::CfgScope scope |
301306
hasVariableReadWithCapturedWrite(bb, any(int j | j > i), v, scope)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/**
2+
* This module provides the public class `AsyncBlockExpr`.
3+
*/
4+
5+
private import codeql.rust.elements.BlockExpr
6+
7+
/**
8+
* An async block expression. For example:
9+
* ```rust
10+
* async {
11+
* let x = 42;
12+
* }
13+
* ```
14+
*/
15+
final class AsyncBlockExpr extends BlockExpr {
16+
AsyncBlockExpr() { this.isAsync() }
17+
}

rust/ql/lib/codeql/rust/elements/internal/AstNodeImpl.qll

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66

77
private import codeql.rust.elements.internal.generated.AstNode
8+
private import codeql.rust.controlflow.ControlFlowGraph
89

910
/**
1011
* INTERNAL: This module contains the customizable definition of `AstNode` and should not
@@ -44,6 +45,17 @@ module Impl {
4445
)
4546
}
4647

48+
/** Gets the CFG scope that encloses this node, if any. */
49+
cached
50+
CfgScope getEnclosingCfgScope() {
51+
exists(AstNode p | p = this.getParentNode() |
52+
result = p
53+
or
54+
not p instanceof CfgScope and
55+
result = p.getEnclosingCfgScope()
56+
)
57+
}
58+
4759
/** Holds if this node is inside a macro expansion. */
4860
predicate isInMacroExpansion() {
4961
this = any(MacroCall mc).getExpanded()

rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
private import rust
2+
private import codeql.rust.controlflow.ControlFlowGraph
23
private import codeql.rust.elements.internal.generated.ParentChild
34
private import codeql.rust.elements.internal.PathExprBaseImpl::Impl as PathExprBaseImpl
45
private import codeql.rust.elements.internal.FormatTemplateVariableAccessImpl::Impl as FormatTemplateVariableAccessImpl
@@ -445,7 +446,7 @@ module Impl {
445446
Variable getVariable() { result = v }
446447

447448
/** Holds if this access is a capture. */
448-
predicate isCapture() { this.getEnclosingCallable() != v.getPat().getEnclosingCallable() }
449+
predicate isCapture() { this.getEnclosingCfgScope() != v.getPat().getEnclosingCfgScope() }
449450

450451
override string toString() { result = name }
451452

rust/ql/lib/rust.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import codeql.Locations
55
import codeql.files.FileSystem
66
import codeql.rust.elements.AssignmentOperation
77
import codeql.rust.elements.LogicalOperation
8+
import codeql.rust.elements.AsyncBlockExpr
89
import codeql.rust.elements.Variable
910
import codeql.rust.elements.NamedFormatArgument
1011
import codeql.rust.elements.PositionalFormatArgument

0 commit comments

Comments
 (0)