Skip to content

Commit 0d7a2f6

Browse files
committed
Rust: Unify type inference for tuple indexing expressions
1 parent b40118d commit 0d7a2f6

File tree

4 files changed

+102
-72
lines changed

4 files changed

+102
-72
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class TupleType extends Type, TTuple {
119119
}
120120

121121
/** The unit type `()`. */
122-
class UnitType extends TupleType, TTuple {
122+
class UnitType extends TupleType {
123123
UnitType() { this = TTuple(0) }
124124

125125
override string toString() { result = "()" }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 99 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,36 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
11351135
)
11361136
}
11371137

1138+
pragma[inline]
1139+
private Type inferRootTypeDeref(AstNode n) {
1140+
result = inferType(n) and
1141+
result != TRefType()
1142+
or
1143+
// for reference types, lookup members in the type being referenced
1144+
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
1145+
}
1146+
1147+
pragma[nomagic]
1148+
private Type getFieldExprLookupType(FieldExpr fe, string name) {
1149+
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
1150+
}
1151+
1152+
pragma[nomagic]
1153+
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
1154+
exists(string name |
1155+
result = getFieldExprLookupType(fe, name) and
1156+
pos = name.toInt()
1157+
)
1158+
}
1159+
1160+
pragma[nomagic]
1161+
private TupleTypeParameter resolveTupleTypeFieldExpr(FieldExpr fe) {
1162+
exists(int arity, int i |
1163+
TTuple(arity) = getTupleFieldExprLookupType(fe, i) and
1164+
result = TTupleTypeParameter(arity, i)
1165+
)
1166+
}
1167+
11381168
/**
11391169
* A matching configuration for resolving types of field expressions
11401170
* like `x.field`.
@@ -1158,15 +1188,30 @@ private module FieldExprMatchingInput implements MatchingInputSig {
11581188
}
11591189
}
11601190

1161-
abstract class Declaration extends AstNode {
1191+
private newtype TDeclaration =
1192+
TStructFieldDecl(StructField sf) or
1193+
TTupleFieldDecl(TupleField tf) or
1194+
TTupleTypeParameterDecl(TupleTypeParameter ttp)
1195+
1196+
abstract class Declaration extends TDeclaration {
11621197
TypeParameter getTypeParameter(TypeParameterPosition ppos) { none() }
11631198

1199+
abstract Type getDeclaredType(DeclarationPosition dpos, TypePath path);
1200+
1201+
abstract string toString();
1202+
1203+
abstract Location getLocation();
1204+
}
1205+
1206+
abstract private class StructOrTupleFieldDecl extends Declaration {
1207+
abstract AstNode getAstNode();
1208+
11641209
abstract TypeRepr getTypeRepr();
11651210

1166-
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
1211+
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
11671212
dpos.isSelf() and
11681213
// no case for variants as those can only be destructured using pattern matching
1169-
exists(Struct s | s.getStructField(_) = this or s.getTupleField(_) = this |
1214+
exists(Struct s | this.getAstNode() = [s.getStructField(_).(AstNode), s.getTupleField(_)] |
11701215
result = TStruct(s) and
11711216
path.isEmpty()
11721217
or
@@ -1177,14 +1222,55 @@ private module FieldExprMatchingInput implements MatchingInputSig {
11771222
dpos.isField() and
11781223
result = this.getTypeRepr().(TypeMention).resolveTypeAt(path)
11791224
}
1225+
1226+
override string toString() { result = this.getAstNode().toString() }
1227+
1228+
override Location getLocation() { result = this.getAstNode().getLocation() }
1229+
}
1230+
1231+
private class StructFieldDecl extends StructOrTupleFieldDecl, TStructFieldDecl {
1232+
private StructField sf;
1233+
1234+
StructFieldDecl() { this = TStructFieldDecl(sf) }
1235+
1236+
override AstNode getAstNode() { result = sf }
1237+
1238+
override TypeRepr getTypeRepr() { result = sf.getTypeRepr() }
11801239
}
11811240

1182-
private class StructFieldDecl extends Declaration instanceof StructField {
1183-
override TypeRepr getTypeRepr() { result = StructField.super.getTypeRepr() }
1241+
private class TupleFieldDecl extends StructOrTupleFieldDecl, TTupleFieldDecl {
1242+
private TupleField tf;
1243+
1244+
TupleFieldDecl() { this = TTupleFieldDecl(tf) }
1245+
1246+
override AstNode getAstNode() { result = tf }
1247+
1248+
override TypeRepr getTypeRepr() { result = tf.getTypeRepr() }
11841249
}
11851250

1186-
private class TupleFieldDecl extends Declaration instanceof TupleField {
1187-
override TypeRepr getTypeRepr() { result = TupleField.super.getTypeRepr() }
1251+
private class TupleTypeParameterDecl extends Declaration, TTupleTypeParameterDecl {
1252+
private TupleTypeParameter ttp;
1253+
1254+
TupleTypeParameterDecl() { this = TTupleTypeParameterDecl(ttp) }
1255+
1256+
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
1257+
dpos.isSelf() and
1258+
(
1259+
result = ttp.getTupleType() and
1260+
path.isEmpty()
1261+
or
1262+
result = ttp and
1263+
path = TypePath::singleton(ttp)
1264+
)
1265+
or
1266+
dpos.isField() and
1267+
result = ttp and
1268+
path.isEmpty()
1269+
}
1270+
1271+
override string toString() { result = ttp.toString() }
1272+
1273+
override Location getLocation() { result = ttp.getLocation() }
11881274
}
11891275

11901276
class AccessPosition = DeclarationPosition;
@@ -1206,7 +1292,12 @@ private module FieldExprMatchingInput implements MatchingInputSig {
12061292

12071293
Declaration getTarget() {
12081294
// mutual recursion; resolving fields requires resolving types and vice versa
1209-
result = [resolveStructFieldExpr(this).(AstNode), resolveTupleFieldExpr(this)]
1295+
result =
1296+
[
1297+
TStructFieldDecl(resolveStructFieldExpr(this)).(TDeclaration),
1298+
TTupleFieldDecl(resolveTupleFieldExpr(this)),
1299+
TTupleTypeParameterDecl(resolveTupleTypeFieldExpr(this))
1300+
]
12101301
}
12111302
}
12121303

@@ -1266,42 +1357,6 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
12661357
)
12671358
}
12681359

1269-
pragma[nomagic]
1270-
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
1271-
exists(int i, TypePath path0 |
1272-
fe.getIdentifier().getText() = i.toString() and
1273-
result = inferType(fe.getContainer(), path0) and
1274-
path0.isCons(TTupleTypeParameter(_, i), path) and
1275-
fe.getIdentifier().getText() = i.toString()
1276-
)
1277-
}
1278-
1279-
/** Infers the type of `t` in `t.n` when `t` is a tuple. */
1280-
private Type inferTupleContainerExprType(Expr e, TypePath path) {
1281-
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
1282-
// tuple as in:
1283-
// ```rust
1284-
// let t = (Default::default(), 2);
1285-
// let s: String = t.0;
1286-
// ```
1287-
// But it could also be a tuple struct as in:
1288-
// ```rust
1289-
// struct T(String, u32);
1290-
// let t = T(Default::default(), 2);
1291-
// let s: String = t.0;
1292-
// ```
1293-
// We need type information to flow from `t.n` to tuple type parameters of `t`
1294-
// in the former case but not the latter case. Hence we include the condition
1295-
// that the root type of `t` must be a tuple type.
1296-
exists(int i, TypePath path0, FieldExpr fe, int arity |
1297-
e = fe.getContainer() and
1298-
fe.getIdentifier().getText() = i.toString() and
1299-
arity = inferType(fe.getContainer()).(TupleType).getArity() and
1300-
result = inferType(fe, path0) and
1301-
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
1302-
)
1303-
}
1304-
13051360
/** Gets the root type of the reference node `ref`. */
13061361
pragma[nomagic]
13071362
private Type inferRefNodeType(AstNode ref) {
@@ -2230,20 +2285,6 @@ private module Cached {
22302285
result = resolveFunctionCallTarget(call)
22312286
}
22322287

2233-
pragma[inline]
2234-
private Type inferRootTypeDeref(AstNode n) {
2235-
result = inferType(n) and
2236-
result != TRefType()
2237-
or
2238-
// for reference types, lookup members in the type being referenced
2239-
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
2240-
}
2241-
2242-
pragma[nomagic]
2243-
private Type getFieldExprLookupType(FieldExpr fe, string name) {
2244-
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
2245-
}
2246-
22472288
/**
22482289
* Gets the struct field that the field expression `fe` resolves to, if any.
22492290
*/
@@ -2252,14 +2293,6 @@ private module Cached {
22522293
exists(string name | result = getFieldExprLookupType(fe, name).getStructField(name))
22532294
}
22542295

2255-
pragma[nomagic]
2256-
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
2257-
exists(string name |
2258-
result = getFieldExprLookupType(fe, name) and
2259-
pos = name.toInt()
2260-
)
2261-
}
2262-
22632296
/**
22642297
* Gets the tuple field that the field expression `fe` resolves to, if any.
22652298
*/
@@ -2341,10 +2374,6 @@ private module Cached {
23412374
or
23422375
result = inferFieldExprType(n, path)
23432376
or
2344-
result = inferTupleIndexExprType(n, path)
2345-
or
2346-
result = inferTupleContainerExprType(n, path)
2347-
or
23482377
result = inferRefNodeType(n) and
23492378
path.isEmpty()
23502379
or

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2487,7 +2487,7 @@ mod tuples {
24872487
let x = pair.0; // $ type=x:i32
24882488

24892489
let y = &S1::get_pair(); // $ target=get_pair
2490-
y.0.foo(); // $ MISSING: target=foo
2490+
y.0.foo(); // $ target=foo
24912491
}
24922492
}
24932493

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4856,6 +4856,7 @@ inferType
48564856
| main.rs:2490:9:2490:9 | y | &T | file://:0:0:0:0 | (T_2) |
48574857
| main.rs:2490:9:2490:9 | y | &T.0(2) | main.rs:2447:5:2448:16 | S1 |
48584858
| main.rs:2490:9:2490:9 | y | &T.1(2) | main.rs:2447:5:2448:16 | S1 |
4859+
| main.rs:2490:9:2490:11 | y.0 | | main.rs:2447:5:2448:16 | S1 |
48594860
| main.rs:2497:13:2497:23 | boxed_value | | {EXTERNAL LOCATION} | Box |
48604861
| main.rs:2497:13:2497:23 | boxed_value | A | {EXTERNAL LOCATION} | Global |
48614862
| main.rs:2497:13:2497:23 | boxed_value | T | {EXTERNAL LOCATION} | i32 |

0 commit comments

Comments
 (0)