Skip to content

Commit 821f2fd

Browse files
committed
Rust: Type inference for .await expressions
1 parent e6109cf commit 821f2fd

File tree

6 files changed

+229
-14
lines changed

6 files changed

+229
-14
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,19 @@ class ResultEnum extends Enum {
4949
/** Gets the `Err` variant. */
5050
Variant getErr() { result = this.getVariant("Err") }
5151
}
52+
53+
/**
54+
* The [`Future` trait][1].
55+
*
56+
* [1]: https://doc.rust-lang.org/std/future/trait.Future.html
57+
*/
58+
class FutureTrait extends Trait {
59+
FutureTrait() { this.getCanonicalPath() = "core::future::future::Future" }
60+
61+
/** Gets the `Output` associated type. */
62+
pragma[nomagic]
63+
TypeAlias getOutputType() {
64+
result = this.getAssocItemList().getAnAssocItem() and
65+
result.getName().getText() = "Output"
66+
}
67+
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@ newtype TType =
1515
TTrait(Trait t) or
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
18+
TImplTraitType(int bounds) {
19+
bounds = any(ImplTraitTypeRepr impl).getTypeBoundList().getNumberOfBounds()
20+
} or
1821
TTypeParamTypeParameter(TypeParam t) or
1922
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2023
TRefTypeParameter() or
21-
TSelfTypeParameter(Trait t)
24+
TSelfTypeParameter(Trait t) or
25+
TImplTraitTypeParameter(ImplTraitType t, int i) { i in [0 .. t.getNumberOfBounds() - 1] }
2226

2327
/**
2428
* A type without type arguments.
@@ -115,6 +119,9 @@ class TraitType extends Type, TTrait {
115119

116120
TraitType() { this = TTrait(trait) }
117121

122+
/** Gets the underlying trait. */
123+
Trait getTrait() { result = trait }
124+
118125
override StructField getStructField(string name) { none() }
119126

120127
override TupleField getTupleField(int i) { none() }
@@ -176,6 +183,33 @@ class RefType extends Type, TRefType {
176183
override Location getLocation() { result instanceof EmptyLocation }
177184
}
178185

186+
/**
187+
* An [`impl Trait`][1] type.
188+
*
189+
* We represent `impl Trait` types as generic types with as many type parameters
190+
* as there are bounds.
191+
*
192+
* [1] https://doc.rust-lang.org/book/ch10-02-traits.html#traits-as-parameters
193+
*/
194+
class ImplTraitType extends Type, TImplTraitType {
195+
private int bounds;
196+
197+
ImplTraitType() { this = TImplTraitType(bounds) }
198+
199+
/** Gets the number of bounds of this `impl Trait` type. */
200+
int getNumberOfBounds() { result = bounds }
201+
202+
override StructField getStructField(string name) { none() }
203+
204+
override TupleField getTupleField(int i) { none() }
205+
206+
override TypeParameter getTypeParameter(int i) { result = TImplTraitTypeParameter(this, i) }
207+
208+
override string toString() { result = "impl Trait ..." }
209+
210+
override Location getLocation() { result instanceof EmptyLocation }
211+
}
212+
179213
/** A type parameter. */
180214
abstract class TypeParameter extends Type {
181215
override StructField getStructField(string name) { none() }
@@ -281,6 +315,26 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
281315
override Location getLocation() { result = trait.getLocation() }
282316
}
283317

318+
/**
319+
* An `impl Trait` type parameter.
320+
*/
321+
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
322+
private ImplTraitType implTraitType;
323+
private int i;
324+
325+
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTraitType, i) }
326+
327+
/** Gets the `impl Trait` type that this parameter belongs to. */
328+
ImplTraitType getImplTraitType() { result = implTraitType }
329+
330+
/** Gets the index of this type parameter. */
331+
int getIndex() { result = i }
332+
333+
override string toString() { result = "impl Trait<" + i.toString() + ">" }
334+
335+
override Location getLocation() { result instanceof EmptyLocation }
336+
}
337+
284338
/**
285339
* A type abstraction. I.e., a place in the program where type variables are
286340
* introduced.

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ private module Input1 implements InputSig1<Location> {
7777
apos.asMethodTypeArgumentPosition() = ppos.asTypeParam().getPosition()
7878
}
7979

80+
private int getImplTraitTypeParameterId(ImplTraitTypeParameter tp) {
81+
tp =
82+
rank[result](ImplTraitTypeParameter tp0, int bounds, int i |
83+
bounds = tp0.getImplTraitType().getNumberOfBounds() and
84+
i = tp0.getIndex()
85+
|
86+
tp0 order by bounds, i
87+
)
88+
}
89+
8090
int getTypeParameterId(TypeParameter tp) {
8191
tp =
8292
rank[result](TypeParameter tp0, int kind, int id |
@@ -90,6 +100,9 @@ private module Input1 implements InputSig1<Location> {
90100
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
91101
node = tp0.(SelfTypeParameter).getTrait()
92102
)
103+
or
104+
kind = 2 and
105+
id = getImplTraitTypeParameterId(tp0)
93106
|
94107
tp0 order by kind, id
95108
)
@@ -228,7 +241,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
228241
or
229242
n1 = n2.(ParenExpr).getExpr()
230243
or
231-
n1 = n2.(BlockExpr).getStmtList().getTailExpr()
244+
n2 =
245+
any(BlockExpr be |
246+
not be.isAsync() and
247+
n1 = be.getStmtList().getTailExpr()
248+
)
232249
or
233250
n1 = n2.(IfExpr).getABranch()
234251
or
@@ -1010,6 +1027,29 @@ private StructType inferLiteralType(LiteralExpr le) {
10101027
)
10111028
}
10121029

1030+
pragma[nomagic]
1031+
private AssociatedTypeTypeParameter getFutureOutputTypeParameter() {
1032+
result.getTypeAlias() = any(FutureTrait ft).getOutputType()
1033+
}
1034+
1035+
pragma[nomagic]
1036+
private Type inferAwaitExprType(AwaitExpr ae, TypePath path) {
1037+
exists(TypePath exprPath | result = inferType(ae.getExpr(), exprPath) |
1038+
exprPath
1039+
.isCons(TImplTraitTypeParameter(_, _),
1040+
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), path)))
1041+
or
1042+
path = exprPath and
1043+
not (
1044+
exprPath = TypePath::singleton(TImplTraitTypeParameter(_, _)) and
1045+
result.(TraitType).getTrait() instanceof FutureTrait
1046+
) and
1047+
not exprPath
1048+
.isCons(TImplTraitTypeParameter(_, _),
1049+
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), _)))
1050+
)
1051+
}
1052+
10131053
private module MethodCall {
10141054
/** An expression that calls a method. */
10151055
abstract private class MethodCallImpl extends Expr {
@@ -1119,12 +1159,17 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
11191159
}
11201160

11211161
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1162+
pragma[nomagic]
1163+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1164+
rootType = mc.getTypeAt(TypePath::nil()) and
1165+
name = mc.getMethodName() and
1166+
arity = mc.getArity()
1167+
}
1168+
11221169
pragma[nomagic]
11231170
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
11241171
exists(Type rootType, string name, int arity |
1125-
rootType = mc.getTypeAt(TypePath::nil()) and
1126-
name = mc.getMethodName() and
1127-
arity = mc.getArity() and
1172+
isMethodCall(mc, rootType, name, arity) and
11281173
constraint = impl.(ImplTypeAbstraction).getSelfTy()
11291174
|
11301175
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
@@ -1161,6 +1206,12 @@ private Function getMethodFromImpl(MethodCall mc) {
11611206
)
11621207
}
11631208

1209+
bindingset[trait, name]
1210+
pragma[inline_late]
1211+
private Function getTraitMethod(TraitType trait, string name) {
1212+
result = getMethodSuccessor(trait.getTrait(), name)
1213+
}
1214+
11641215
/**
11651216
* Gets a method that the method call `mc` resolves to based on type inference,
11661217
* if any.
@@ -1172,6 +1223,11 @@ private Function inferMethodCallTarget(MethodCall mc) {
11721223
// The type of the receiver is a type parameter and the method comes from a
11731224
// trait bound on the type parameter.
11741225
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1226+
or
1227+
// The type of the receiver is an `impl Trait` type.
1228+
result =
1229+
getTraitMethod(mc.getTypeAt(TypePath::singleton(TImplTraitTypeParameter(_, _))),
1230+
mc.getMethodName())
11751231
}
11761232

11771233
cached
@@ -1347,6 +1403,8 @@ private module Cached {
13471403
or
13481404
result = inferLiteralType(n) and
13491405
path.isEmpty()
1406+
or
1407+
result = inferAwaitExprType(n, path)
13501408
}
13511409
}
13521410

@@ -1363,7 +1421,7 @@ private module Debug {
13631421
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
13641422
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
13651423
filepath.matches("%/main.rs") and
1366-
startline = 948
1424+
startline = 1334
13671425
)
13681426
}
13691427

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ abstract class TypeMention extends AstNode {
1515

1616
/** Gets the sub mention at `path`. */
1717
pragma[nomagic]
18-
private TypeMention getMentionAt(TypePath path) {
18+
TypeMention getMentionAt(TypePath path) {
1919
path.isEmpty() and
2020
result = this
2121
or
@@ -150,6 +150,54 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
150150
not exists(resolved.(TypeAlias).getTypeRepr()) and
151151
result = super.resolveTypeAt(typePath)
152152
}
153+
154+
pragma[nomagic]
155+
private TypeAlias getResolvedTraitAlias(string name) {
156+
exists(TraitItemNode trait |
157+
trait = resolvePath(path) and
158+
result = trait.getAnAssocItem() and
159+
name = result.getName().getText()
160+
)
161+
}
162+
163+
pragma[nomagic]
164+
private TypeRepr getAssocTypeArg(string name) {
165+
exists(AssocTypeArg arg |
166+
arg = path.getSegment().getGenericArgList().getAGenericArg() and
167+
result = arg.getTypeRepr() and
168+
name = arg.getIdentifier().getText()
169+
)
170+
}
171+
172+
/** Gets the type argument for the associated type `alias`, if any. */
173+
pragma[nomagic]
174+
private TypeRepr getAnAssocTypeArgument(TypeAlias alias) {
175+
exists(string name |
176+
alias = this.getResolvedTraitAlias(name) and
177+
result = this.getAssocTypeArg(name)
178+
)
179+
}
180+
181+
override TypeMention getMentionAt(TypePath tp) {
182+
result = super.getMentionAt(tp)
183+
or
184+
exists(TypeAlias alias, AssociatedTypeTypeParameter attp, TypeMention arg, TypePath suffix |
185+
arg = this.getAnAssocTypeArgument(alias) and
186+
result = arg.getMentionAt(suffix) and
187+
tp = TypePath::cons(attp, suffix) and
188+
attp.getTypeAlias() = alias
189+
)
190+
}
191+
}
192+
193+
class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr {
194+
override TypeMention getTypeArgument(int i) {
195+
result = super.getTypeBoundList().getBound(i).getTypeRepr()
196+
}
197+
198+
override ImplTraitType resolveType() {
199+
result.getNumberOfBounds() = super.getTypeBoundList().getNumberOfBounds()
200+
}
153201
}
154202

155203
private TypeParameter pathGetTypeParameter(TypeAlias alias, int i) {

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,9 +1664,9 @@ mod async_ {
16641664
}
16651665

16661666
pub async fn f() {
1667-
f1().await.f(); // $ MISSING: method=S1f
1668-
f2().await.f(); // $ MISSING: method=S1f
1669-
f3().await.f(); // $ MISSING: method=S1f
1667+
f1().await.f(); // $ method=S1f
1668+
f2().await.f(); // $ method=S1f
1669+
f3().await.f(); // $ method=S1f
16701670
}
16711671
}
16721672

@@ -1696,8 +1696,8 @@ mod impl_trait {
16961696

16971697
pub fn f() {
16981698
let x = f1();
1699-
x.f1(); // $ MISSING: method=Trait1f1
1700-
x.f2(); // $ MISSING: method=Trait2f2
1699+
x.f1(); // $ method=Trait1f1
1700+
x.f2(); // $ method=Trait2f2
17011701
}
17021702
}
17031703

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,8 +2377,12 @@ inferType
23772377
| main.rs:1639:18:1639:21 | SelfParam | | main.rs:1636:5:1636:14 | S1 |
23782378
| main.rs:1642:25:1644:5 | { ... } | | main.rs:1636:5:1636:14 | S1 |
23792379
| main.rs:1643:9:1643:10 | S1 | | main.rs:1636:5:1636:14 | S1 |
2380-
| main.rs:1646:41:1650:5 | { ... } | | main.rs:1636:5:1636:14 | S1 |
2381-
| main.rs:1647:9:1649:9 | { ... } | | main.rs:1636:5:1636:14 | S1 |
2380+
| main.rs:1646:41:1650:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
2381+
| main.rs:1646:41:1650:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2382+
| main.rs:1646:41:1650:5 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
2383+
| main.rs:1647:9:1649:9 | { ... } | | file://:0:0:0:0 | impl Trait ... |
2384+
| main.rs:1647:9:1649:9 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2385+
| main.rs:1647:9:1649:9 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
23822386
| main.rs:1648:13:1648:14 | S1 | | main.rs:1636:5:1636:14 | S1 |
23832387
| main.rs:1657:17:1657:46 | SelfParam | | {EXTERNAL LOCATION} | Pin |
23842388
| main.rs:1657:17:1657:46 | SelfParam | Ptr | file://:0:0:0:0 | & |
@@ -2390,9 +2394,26 @@ inferType
23902394
| main.rs:1658:13:1658:38 | ...::Ready(...) | | {EXTERNAL LOCATION} | Poll |
23912395
| main.rs:1658:13:1658:38 | ...::Ready(...) | T | main.rs:1636:5:1636:14 | S1 |
23922396
| main.rs:1658:36:1658:37 | S1 | | main.rs:1636:5:1636:14 | S1 |
2397+
| main.rs:1662:41:1664:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
23932398
| main.rs:1662:41:1664:5 | { ... } | | main.rs:1652:5:1652:14 | S2 |
2399+
| main.rs:1662:41:1664:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2400+
| main.rs:1662:41:1664:5 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
2401+
| main.rs:1663:9:1663:10 | S2 | | file://:0:0:0:0 | impl Trait ... |
23942402
| main.rs:1663:9:1663:10 | S2 | | main.rs:1652:5:1652:14 | S2 |
2403+
| main.rs:1663:9:1663:10 | S2 | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2404+
| main.rs:1663:9:1663:10 | S2 | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
23952405
| main.rs:1667:9:1667:12 | f1(...) | | main.rs:1636:5:1636:14 | S1 |
2406+
| main.rs:1667:9:1667:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
2407+
| main.rs:1668:9:1668:12 | f2(...) | | file://:0:0:0:0 | impl Trait ... |
2408+
| main.rs:1668:9:1668:12 | f2(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2409+
| main.rs:1668:9:1668:12 | f2(...) | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
2410+
| main.rs:1668:9:1668:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
2411+
| main.rs:1668:9:1668:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
2412+
| main.rs:1669:9:1669:12 | f3(...) | | file://:0:0:0:0 | impl Trait ... |
2413+
| main.rs:1669:9:1669:12 | f3(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
2414+
| main.rs:1669:9:1669:12 | f3(...) | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
2415+
| main.rs:1669:9:1669:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
2416+
| main.rs:1669:9:1669:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
23962417
| main.rs:1678:15:1678:19 | SelfParam | | file://:0:0:0:0 | & |
23972418
| main.rs:1678:15:1678:19 | SelfParam | &T | main.rs:1677:5:1679:5 | Self [trait Trait1] |
23982419
| main.rs:1682:15:1682:19 | SelfParam | | file://:0:0:0:0 | & |
@@ -2401,8 +2422,26 @@ inferType
24012422
| main.rs:1686:15:1686:19 | SelfParam | &T | main.rs:1675:5:1675:14 | S1 |
24022423
| main.rs:1690:15:1690:19 | SelfParam | | file://:0:0:0:0 | & |
24032424
| main.rs:1690:15:1690:19 | SelfParam | &T | main.rs:1675:5:1675:14 | S1 |
2425+
| main.rs:1693:37:1695:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
24042426
| main.rs:1693:37:1695:5 | { ... } | | main.rs:1675:5:1675:14 | S1 |
2427+
| main.rs:1693:37:1695:5 | { ... } | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2428+
| main.rs:1693:37:1695:5 | { ... } | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
2429+
| main.rs:1694:9:1694:10 | S1 | | file://:0:0:0:0 | impl Trait ... |
24052430
| main.rs:1694:9:1694:10 | S1 | | main.rs:1675:5:1675:14 | S1 |
2431+
| main.rs:1694:9:1694:10 | S1 | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2432+
| main.rs:1694:9:1694:10 | S1 | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
2433+
| main.rs:1698:13:1698:13 | x | | file://:0:0:0:0 | impl Trait ... |
2434+
| main.rs:1698:13:1698:13 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2435+
| main.rs:1698:13:1698:13 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
2436+
| main.rs:1698:17:1698:20 | f1(...) | | file://:0:0:0:0 | impl Trait ... |
2437+
| main.rs:1698:17:1698:20 | f1(...) | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2438+
| main.rs:1698:17:1698:20 | f1(...) | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
2439+
| main.rs:1699:9:1699:9 | x | | file://:0:0:0:0 | impl Trait ... |
2440+
| main.rs:1699:9:1699:9 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2441+
| main.rs:1699:9:1699:9 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
2442+
| main.rs:1700:9:1700:9 | x | | file://:0:0:0:0 | impl Trait ... |
2443+
| main.rs:1700:9:1700:9 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
2444+
| main.rs:1700:9:1700:9 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
24062445
| main.rs:1706:5:1706:20 | ...::f(...) | | main.rs:67:5:67:21 | Foo |
24072446
| main.rs:1707:5:1707:60 | ...::g(...) | | main.rs:67:5:67:21 | Foo |
24082447
| main.rs:1707:20:1707:38 | ...::Foo {...} | | main.rs:67:5:67:21 | Foo |

0 commit comments

Comments
 (0)