Skip to content

Commit b01c74c

Browse files
committed
Rust: Implement basic type inference in QL
1 parent d56bf65 commit b01c74c

File tree

3 files changed

+332
-4
lines changed

3 files changed

+332
-4
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallExprBaseImpl.qll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module Impl {
1717
private import codeql.rust.elements.internal.CallExprImpl::Impl
1818
private import codeql.rust.elements.internal.PathExprImpl::Impl
1919
private import codeql.rust.elements.internal.PathResolution
20+
private import codeql.rust.elements.internal.TypeInference
2021

2122
pragma[nomagic]
2223
Resolvable getCallResolvable(CallExprBase call) {
@@ -35,9 +36,11 @@ module Impl {
3536
* be statically resolved.
3637
*/
3738
Callable getStaticTarget() {
38-
getCallResolvable(this).resolvesAsItem(result)
39-
or
39+
// getCallResolvable(this).resolvesAsItem(result)
40+
// or
4041
result = resolvePath(this.(CallExpr).getFunction().(PathExpr).getPath())
42+
or
43+
result = resolveMethodCallExpr(this)
4144
}
4245
}
4346
}

rust/ql/lib/codeql/rust/elements/internal/PathResolution.qll

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,10 @@ abstract private class ImplOrTraitItemNode extends ItemNode {
184184
}
185185
}
186186

187-
private class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
187+
class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
188188
override string getName() { result = "(impl)" }
189189

190-
override Visibility getVisibility() { none() }
190+
override Visibility getVisibility() { result = Impl.super.getVisibility() }
191191
}
192192

193193
private class MacroCallItemNode extends ItemNode instanceof MacroCall {
@@ -232,6 +232,12 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
232232
override Visibility getVisibility() { none() }
233233
}
234234

235+
private class TypeParamItemNode extends ItemNode instanceof TypeParam {
236+
override string getName() { result = TypeParam.super.getName().getText() }
237+
238+
override Visibility getVisibility() { none() }
239+
}
240+
235241
/** Holds if `item` has the name `name` and is a top-level item inside `f`. */
236242
private predicate sourceFileEdge(SourceFile f, string name, ItemNode item) {
237243
item = f.getAnItem() and
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/** Provides functionality for inferring types. */
2+
3+
private import rust
4+
private import PathResolution
5+
6+
/** Gets the singleton type path `i`. */
7+
bindingset[i]
8+
private TypePath typePath(int i) { result = i.toString() }
9+
10+
bindingset[s]
11+
private predicate decodeTypePathComponent(string s, int i) { i = s.toInt() }
12+
13+
/**
14+
* A path into a (constructed) type.
15+
*
16+
* Paths are represented in left-to-right order, for example, a path `0.1` into the
17+
* type `C1<C2<A,B>,C3<C,D>>` points at the type `B`.
18+
*
19+
* Type paths are used to represent constructed types without using a `newtype`, which
20+
* makes it practically feasible to do type inference in mutual recursion with call
21+
* resolution.
22+
*
23+
* As an example, the type above can be represented by the following set of tuples
24+
*
25+
* `TypePath` | `Type`
26+
* ---------- | ------
27+
* `""` | ``C1``
28+
* `"0"` | ``C2``
29+
* `"0.0"` | `A`
30+
* `"0.1"` | `B`
31+
* `"1"` | ``C3``
32+
* `"1.0"` | `C`
33+
* `"1.1"` | `D`
34+
*/
35+
class TypePath extends string {
36+
bindingset[this]
37+
TypePath() { exists(this) }
38+
39+
predicate isEmpty() { this = "" }
40+
41+
/** Gets the path obtained by appending `suffix` onto this path. */
42+
bindingset[suffix, result]
43+
bindingset[this, result]
44+
bindingset[this, suffix]
45+
TypePath append(TypePath suffix) {
46+
if this.isEmpty()
47+
then result = suffix
48+
else
49+
if suffix.isEmpty()
50+
then result = this
51+
else result = this + "." + suffix
52+
}
53+
54+
/** Holds if this path starts with `prefix`, followed by `i`. */
55+
bindingset[this]
56+
predicate endsWith(TypePath prefix, int i) {
57+
decodeTypePathComponent(this, i) and
58+
prefix.isEmpty()
59+
or
60+
exists(int last |
61+
last = max(this.indexOf(".")) and
62+
prefix = this.prefix(last) and
63+
decodeTypePathComponent(this.suffix(last + 1), i)
64+
)
65+
}
66+
67+
/** Holds if this path starts with `i`, followed by `suffix`. */
68+
bindingset[this]
69+
predicate startsWith(int i, TypePath suffix) {
70+
decodeTypePathComponent(this, i) and
71+
suffix.isEmpty()
72+
or
73+
exists(int first |
74+
first = min(this.indexOf(".")) and
75+
suffix = this.suffix(first + 1) and
76+
decodeTypePathComponent(this.prefix(first), i)
77+
)
78+
}
79+
}
80+
81+
private predicate letStmtTyped(LetStmt let, Pat pat, TypeRepr t) {
82+
pat = let.getPat() and
83+
t = let.getTypeRepr() and
84+
not t instanceof InferTypeRepr
85+
}
86+
87+
private predicate paramTyped(Param p, Pat pat, TypeRepr t) {
88+
pat = p.getPat() and
89+
t = p.getTypeRepr() and
90+
not t instanceof InferTypeRepr
91+
}
92+
93+
private predicate isTargetTyped(AstNode n) {
94+
exists(Variable v |
95+
n = v.getPat() and
96+
not letStmtTyped(_, n, _) and
97+
not paramTyped(_, n, _)
98+
)
99+
}
100+
101+
// todo: add more cases
102+
private newtype TType =
103+
TStruct(Struct s) or
104+
TEnum(Enum e) or
105+
TArrayType() or // todo: add size?
106+
TRefType() or // todo: add mut, lifetime?
107+
TTypeParameter(TypeParam t)
108+
109+
abstract private class Type extends TType {
110+
pragma[nomagic]
111+
abstract Function getMethod(string name);
112+
113+
abstract string toString();
114+
115+
abstract Location getLocation();
116+
}
117+
118+
class StructType extends Type {
119+
private Struct struct;
120+
121+
StructType() { this = TStruct(struct) }
122+
123+
override Function getMethod(string name) {
124+
exists(ImplItemNode i |
125+
struct = resolvePath(i.(Impl).getSelfTy().(PathTypeRepr).getPath()) and
126+
result = i.getASuccessor(name)
127+
)
128+
}
129+
130+
override string toString() { result = struct.toString() }
131+
132+
override Location getLocation() { result = struct.getLocation() }
133+
}
134+
135+
class EnumType extends Type {
136+
private Enum enum;
137+
138+
EnumType() { this = TEnum(enum) }
139+
140+
override Function getMethod(string name) {
141+
exists(ImplItemNode i |
142+
enum = resolvePath(i.(Impl).getSelfTy().(PathTypeRepr).getPath()) and
143+
result = i.getASuccessor(name)
144+
)
145+
}
146+
147+
override string toString() { result = enum.toString() }
148+
149+
override Location getLocation() { result = enum.getLocation() }
150+
}
151+
152+
class ArrayType extends Type {
153+
ArrayType() { this = TArrayType() }
154+
155+
override Function getMethod(string name) { none() }
156+
157+
override string toString() { result = "[]" }
158+
159+
override Location getLocation() { result instanceof EmptyLocation }
160+
}
161+
162+
class RefType extends Type {
163+
RefType() { this = TRefType() }
164+
165+
override Function getMethod(string name) { none() }
166+
167+
override string toString() { result = "&" }
168+
169+
override Location getLocation() { result instanceof EmptyLocation }
170+
}
171+
172+
class TypeParameter extends Type {
173+
private TypeParam typeParam;
174+
175+
TypeParameter() { this = TTypeParameter(typeParam) }
176+
177+
override Function getMethod(string name) { none() }
178+
179+
override string toString() { result = typeParam.toString() }
180+
181+
override Location getLocation() { result = typeParam.getLocation() }
182+
}
183+
184+
abstract private class TypeReprOrPath extends AstNode {
185+
TypeRepr getTypeReprArgument(int i) {
186+
result = this.(ArrayTypeRepr).getElementTypeRepr() and
187+
i = 0
188+
or
189+
result = this.(RefTypeRepr).getTypeRepr() and
190+
i = 0
191+
or
192+
result = this.(Path).getPart().getGenericArgList().getGenericArg(i).(TypeArg).getTypeRepr()
193+
or
194+
result = this.(PathTypeRepr).getPath().(PathTypeReprOrPath).getTypeReprArgument(i)
195+
}
196+
}
197+
198+
private class TypeReprTypeReprOrPath extends TypeReprOrPath, TypeRepr { }
199+
200+
private class PathTypeReprOrPath extends TypeReprOrPath, Path { }
201+
202+
private TypeReprOrPath getTypeReprAt(TypeReprOrPath t, TypePath path) {
203+
path.isEmpty() and
204+
result = t
205+
or
206+
exists(int i, TypeReprOrPath arg, TypePath suffix |
207+
arg = t.getTypeReprArgument(i) and
208+
result = getTypeReprAt(arg, suffix) and
209+
path = typePath(i).append(suffix)
210+
)
211+
}
212+
213+
private Type resolveTypeRepr(TypeReprOrPath t) {
214+
t instanceof ArrayTypeRepr and
215+
result = TArrayType()
216+
or
217+
t instanceof RefTypeRepr and
218+
result = TRefType()
219+
or
220+
t =
221+
any(Path path |
222+
result = TStruct(resolvePath(path))
223+
or
224+
result = TEnum(resolvePath(path))
225+
or
226+
result = TTypeParameter(resolvePath(path))
227+
)
228+
or
229+
result = resolveTypeRepr(t.(PathTypeRepr).getPath())
230+
}
231+
232+
private Type resolveTypeAt(TypeReprOrPath t, TypePath path) {
233+
result = resolveTypeRepr(getTypeReprAt(t, path))
234+
}
235+
236+
private Type resolveVariableType(AstNode n, TypePath path) {
237+
exists(TypeRepr t | letStmtTyped(_, n, t) or paramTyped(_, n, t) |
238+
result = resolveTypeAt(t, path)
239+
)
240+
or
241+
exists(Variable v |
242+
result = resolveType(v.getPat(), path) and
243+
n = v.getAnAccess()
244+
)
245+
}
246+
247+
pragma[nomagic]
248+
private Type resolveTargetTyped(AstNode n, TypePath path) {
249+
isTargetTyped(n) and
250+
exists(LetStmt let |
251+
let.getPat() = n and
252+
result = resolveType(let.getInitializer(), path)
253+
)
254+
}
255+
256+
private Type resolveRecordExprType(RecordExpr re, TypePath path) {
257+
result = resolveTypeAt(re.getPath(), path)
258+
}
259+
260+
pragma[nomagic]
261+
private Type resolveCallExprType(CallExpr ce, TypePath path) {
262+
exists(ItemNode i | i = resolvePath(ce.getFunction().(PathExpr).getPath()) |
263+
result = resolveTypeAt(i.(Function).getRetType().getTypeRepr(), path)
264+
or
265+
exists(Enum e |
266+
i = e.getVariantList().getAVariant() and
267+
result = TEnum(e) and
268+
path.isEmpty()
269+
)
270+
)
271+
}
272+
273+
pragma[nomagic]
274+
Function resolveMethodCallExpr(MethodCallExpr mce) {
275+
exists(Type t |
276+
t = resolveType(mce.getReceiver()) and
277+
if t = TRefType()
278+
then
279+
// for reference types, lookup the method in the type being referenced
280+
result = resolveType(mce.getReceiver(), "0").getMethod(mce.getNameRef().getText())
281+
else result = t.getMethod(mce.getNameRef().getText())
282+
)
283+
}
284+
285+
pragma[nomagic]
286+
private Type resolveMethodCallExprType(MethodCallExpr mce, TypePath path) {
287+
exists(Function f |
288+
f = resolveMethodCallExpr(mce) and
289+
result = resolveTypeAt(f.getRetType().getTypeRepr(), path)
290+
)
291+
}
292+
293+
pragma[nomagic]
294+
private Type resolveRefExprType(RefExpr re, TypePath path) {
295+
exists(re) and
296+
path.isEmpty() and
297+
result = TRefType()
298+
or
299+
exists(TypePath suffix |
300+
result = resolveType(re.getExpr(), suffix) and
301+
path = typePath(0).append(suffix)
302+
)
303+
}
304+
305+
Type resolveType(AstNode n, TypePath path) {
306+
result = resolveVariableType(n, path)
307+
or
308+
result = resolveTargetTyped(n, path)
309+
or
310+
result = resolveRecordExprType(n, path)
311+
or
312+
result = resolveCallExprType(n, path)
313+
or
314+
result = resolveMethodCallExprType(n, path)
315+
or
316+
result = resolveRefExprType(n, path)
317+
}
318+
319+
Type resolveType(AstNode n) { result = resolveType(n, "") }

0 commit comments

Comments
 (0)