Skip to content

Commit 03a9a16

Browse files
committed
Rust: Add type inference for tuples
1 parent 21c030f commit 03a9a16

File tree

6 files changed

+571
-39
lines changed

6 files changed

+571
-39
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@ private import codeql.rust.elements.internal.generated.Synth
99

1010
cached
1111
newtype TType =
12-
TUnit() or
13-
TStruct(Struct s) { Stages::TypeInferenceStage::ref() } or
12+
TTuple(int arity) {
13+
exists(any(TupleTypeRepr t).getField(arity)) and Stages::TypeInferenceStage::ref()
14+
} or
15+
TStruct(Struct s) or
1416
TEnum(Enum e) or
1517
TTrait(Trait t) or
1618
TArrayType() or // todo: add size?
1719
TRefType() or // todo: add mut?
1820
TImplTraitType(ImplTraitTypeRepr impl) or
1921
TSliceType() or
22+
TTupleTypeParameter(int i) { exists(TTuple(i)) } or
2023
TTypeParamTypeParameter(TypeParam t) or
2124
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2225
TArrayTypeParameter() or
@@ -56,8 +59,8 @@ abstract class Type extends TType {
5659
}
5760

5861
/** The unit type `()`. */
59-
class UnitType extends Type, TUnit {
60-
UnitType() { this = TUnit() }
62+
class UnitType extends Type, TTuple {
63+
UnitType() { this = TTuple(0) }
6164

6265
override StructField getStructField(string name) { none() }
6366

@@ -70,6 +73,25 @@ class UnitType extends Type, TUnit {
7073
override Location getLocation() { result instanceof EmptyLocation }
7174
}
7275

76+
/** A tuple type `(T, ...)`. */
77+
class TupleType extends Type, TTuple {
78+
private int arity;
79+
80+
TupleType() { this = TTuple(arity) and arity > 0 }
81+
82+
override StructField getStructField(string name) { none() }
83+
84+
override TupleField getTupleField(int i) { none() }
85+
86+
override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(i) and i < arity }
87+
88+
int getArity() { result = arity }
89+
90+
override string toString() { result = "(T_" + arity + ")" }
91+
92+
override Location getLocation() { result instanceof EmptyLocation }
93+
}
94+
7395
abstract private class StructOrEnumType extends Type {
7496
abstract ItemNode asItemNode();
7597
}
@@ -329,6 +351,21 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
329351
override Location getLocation() { result = typeAlias.getLocation() }
330352
}
331353

354+
/**
355+
* A tuple type parameter. For instance the `T` in `(T, U)`.
356+
*
357+
* Since tuples are structural their parameters can be represented simply as
358+
* their positional index.
359+
*/
360+
class TupleTypeParameter extends TypeParameter, TTupleTypeParameter {
361+
override string toString() { result = this.getIndex().toString() }
362+
363+
override Location getLocation() { result instanceof EmptyLocation }
364+
365+
/** Gets the index of this tuple type parameter. */
366+
int getIndex() { this = TTupleTypeParameter(result) }
367+
}
368+
332369
/** An implicit array type parameter. */
333370
class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
334371
override string toString() { result = "[T;...]" }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ private module Input1 implements InputSig1<Location> {
103103
node = tp0.(SelfTypeParameter).getTrait() or
104104
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
105105
)
106+
or
107+
kind = 2 and
108+
id = tp0.(TupleTypeParameter).getIndex()
106109
|
107110
tp0 order by kind, id
108111
)
@@ -229,7 +232,7 @@ private Type inferLogicalOperationType(AstNode n, TypePath path) {
229232
private Type inferAssignmentOperationType(AstNode n, TypePath path) {
230233
n instanceof AssignmentOperation and
231234
path.isEmpty() and
232-
result = TUnit()
235+
result instanceof UnitType
233236
}
234237

235238
pragma[nomagic]
@@ -321,6 +324,14 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
321324
prefix1.isEmpty() and
322325
prefix2 = TypePath::singleton(TRefTypeParameter())
323326
or
327+
exists(int i |
328+
prefix1.isEmpty() and
329+
prefix2 = TypePath::singleton(TTupleTypeParameter(i))
330+
|
331+
n1 = n2.(TupleExpr).getField(i) or
332+
n1 = n2.(TuplePat).getField(i)
333+
)
334+
or
324335
exists(BlockExpr be |
325336
n1 = be and
326337
n2 = be.getStmtList().getTailExpr() and
@@ -534,6 +545,12 @@ private Type inferStructExprType(AstNode n, TypePath path) {
534545
)
535546
}
536547

548+
pragma[nomagic]
549+
private Type inferTupleExprRootType(TupleExpr te) {
550+
// `typeEquality` handles the non-root case
551+
result = TTuple(te.getNumberOfFields())
552+
}
553+
537554
pragma[nomagic]
538555
private Type inferPathExprType(PathExpr pe, TypePath path) {
539556
// nullary struct/variant constructors
@@ -1055,6 +1072,31 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
10551072
)
10561073
}
10571074

1075+
pragma[nomagic]
1076+
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
1077+
exists(int i, TypePath path0 |
1078+
fe.getIdentifier().getText() = i.toString() and
1079+
result = inferType(fe.getContainer(), path0) and
1080+
path0.isCons(TTupleTypeParameter(i), path) and
1081+
fe.getIdentifier().getText() = i.toString()
1082+
)
1083+
}
1084+
1085+
/** Infers the type of `t` in `t.n` when `t` is a tuple. */
1086+
private Type inferTupleContainerExprType(Expr e, TypePath path) {
1087+
// NOTE: For a field expression `t.n` where `n` is a number `t` might both be
1088+
// a tuple struct or a tuple. It is only correct to let type information flow
1089+
// from `t.n` to tuple type parameters of `t` in the latter case. Hence we
1090+
// include the condition that the root type of `t` must be a tuple type.
1091+
exists(int i, TypePath path0, FieldExpr fe |
1092+
e = fe.getContainer() and
1093+
fe.getIdentifier().getText() = i.toString() and
1094+
inferType(fe.getContainer()) instanceof TupleType and
1095+
result = inferType(fe, path0) and
1096+
path = TypePath::cons(TTupleTypeParameter(i), path0)
1097+
)
1098+
}
1099+
10581100
/** Gets the root type of the reference node `ref`. */
10591101
pragma[nomagic]
10601102
private Type inferRefNodeType(AstNode ref) {
@@ -1943,12 +1985,19 @@ private module Cached {
19431985
or
19441986
result = inferStructExprType(n, path)
19451987
or
1988+
result = inferTupleExprRootType(n) and
1989+
path.isEmpty()
1990+
or
19461991
result = inferPathExprType(n, path)
19471992
or
19481993
result = inferCallExprBaseType(n, path)
19491994
or
19501995
result = inferFieldExprType(n, path)
19511996
or
1997+
result = inferTupleIndexExprType(n, path)
1998+
or
1999+
result = inferTupleContainerExprType(n, path)
2000+
or
19522001
result = inferRefNodeType(n) and
19532002
path.isEmpty()
19542003
or

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ abstract class TypeMention extends AstNode {
1414
final Type resolveType() { result = this.resolveTypeAt(TypePath::nil()) }
1515
}
1616

17+
class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
18+
override Type resolveTypeAt(TypePath path) {
19+
path.isEmpty() and
20+
result = TTuple(super.getNumberOfFields())
21+
or
22+
exists(TypePath suffix, int i |
23+
result = super.getField(i).(TypeMention).resolveTypeAt(suffix) and
24+
path = TypePath::cons(TTupleTypeParameter(i), suffix)
25+
)
26+
}
27+
}
28+
1729
class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
1830
override Type resolveTypeAt(TypePath path) {
1931
path.isEmpty() and

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,27 +2334,27 @@ mod tuples {
23342334
}
23352335

23362336
pub fn f() {
2337-
let a = S1::get_pair(); // $ target=get_pair MISSING: type=a:(T_2)
2338-
let mut b = S1::get_pair(); // $ target=get_pair MISSING: type=b:(T_2)
2339-
let (c, d) = S1::get_pair(); // $ target=get_pair MISSING: type=c:S1 type=d:S1
2340-
let (mut e, f) = S1::get_pair(); // $ target=get_pair MISSING: type=e:S1 type=f:S1
2341-
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair MISSING: type=g:S1 type=h:S1
2342-
2343-
a.0.foo(); // $ MISSING: target=foo
2344-
b.1.foo(); // $ MISSING: target=foo
2345-
c.foo(); // $ MISSING: target=foo
2346-
d.foo(); // $ MISSING: target=foo
2347-
e.foo(); // $ MISSING: target=foo
2348-
f.foo(); // $ MISSING: target=foo
2349-
g.foo(); // $ MISSING: target=foo
2350-
h.foo(); // $ MISSING: target=foo
2337+
let a = S1::get_pair(); // $ target=get_pair type=a:(T_2)
2338+
let mut b = S1::get_pair(); // $ target=get_pair type=b:(T_2)
2339+
let (c, d) = S1::get_pair(); // $ target=get_pair type=c:S1 type=d:S1
2340+
let (mut e, f) = S1::get_pair(); // $ target=get_pair type=e:S1 type=f:S1
2341+
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair type=g:S1 type=h:S1
2342+
2343+
a.0.foo(); // $ target=foo
2344+
b.1.foo(); // $ target=foo
2345+
c.foo(); // $ target=foo
2346+
d.foo(); // $ target=foo
2347+
e.foo(); // $ target=foo
2348+
f.foo(); // $ target=foo
2349+
g.foo(); // $ target=foo
2350+
h.foo(); // $ target=foo
23512351

23522352
// Here type information must flow from `pair.0` and `pair.1` into
23532353
// `pair` and from `(a, b)` into `a` and `b` in order for the types of
23542354
// `a` and `b` to be inferred.
2355-
let a = Default::default(); // $ MISSING: target=default type=a:i64
2356-
let b = Default::default(); // $ MISSING: target=default MISSING: type=b:bool
2357-
let pair = (a, b); // $ MISSING: type=pair:0.i64 type=pair:1.bool
2355+
let a = Default::default(); // $ target=default type=a:i64
2356+
let b = Default::default(); // $ target=default type=b:bool
2357+
let pair = (a, b); // $ type=pair:0.i64 type=pair:1.bool
23582358
let i: i64 = pair.0;
23592359
let j: bool = pair.1;
23602360
}

rust/ql/test/library-tests/type-inference/pattern_matching.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -446,21 +446,21 @@ pub fn tuple_patterns() {
446446
// TuplePat - Tuple patterns
447447
match tuple {
448448
(1, 2, 3.0) => {
449-
let exact_tuple = tuple; // $ MISSING: type=exact_tuple:?
449+
let exact_tuple = tuple; // $ type=exact_tuple:(T_3)
450450
println!("Exact tuple: {:?}", exact_tuple);
451451
}
452452
(a, b, c) => {
453-
let first_elem = a; // $ MISSING: type=first_elem:i32
454-
let second_elem = b; // $ MISSING: type=second_elem:i64
455-
let third_elem = c; // $ MISSING: type=third_elem:f32
453+
let first_elem = a; // $ type=first_elem:i32
454+
let second_elem = b; // $ type=second_elem:i64
455+
let third_elem = c; // $ type=third_elem:f32
456456
println!("Tuple: ({}, {}, {})", first_elem, second_elem, third_elem);
457457
}
458458
}
459459

460460
// With rest pattern
461461
match tuple {
462462
(first, ..) => {
463-
let tuple_first = first; // $ MISSING: type=tuple_first:i32
463+
let tuple_first = first; // $ type=tuple_first:i32
464464
println!("First element: {}", tuple_first);
465465
}
466466
}
@@ -469,7 +469,7 @@ pub fn tuple_patterns() {
469469
let unit = ();
470470
match unit {
471471
() => {
472-
let unit_value = unit; // $ MISSING: type=unit_value:?
472+
let unit_value = unit; // $ type=unit_value:()
473473
println!("Unit value: {:?}", unit_value);
474474
}
475475
}
@@ -478,7 +478,7 @@ pub fn tuple_patterns() {
478478
let single = (42i32,);
479479
match single {
480480
(x,) => {
481-
let single_elem = x; // $ MISSING: type=single_elem:i32
481+
let single_elem = x; // $ type=single_elem:i32
482482
println!("Single element tuple: {}", single_elem);
483483
}
484484
}
@@ -499,8 +499,8 @@ pub fn parenthesized_patterns() {
499499
let tuple = (1i32, 2i32);
500500
match tuple {
501501
(x, (y)) => {
502-
let paren_x = x; // $ MISSING: type=paren_x:i32
503-
let paren_y = y; // $ MISSING: type=paren_y:i32
502+
let paren_x = x; // $ type=paren_x:i32
503+
let paren_y = y; // $ type=paren_y:i32
504504
println!("Parenthesized in tuple: {}, {}", paren_x, paren_y);
505505
}
506506
}
@@ -630,7 +630,7 @@ pub fn rest_patterns() {
630630
// RestPat - Rest patterns (..)
631631
match tuple {
632632
(first, ..) => {
633-
let rest_first = first; // $ MISSING: type=rest_first:i32
633+
let rest_first = first; // $ type=rest_first:i32
634634
println!("First with rest: {}", rest_first);
635635
}
636636
}
@@ -644,7 +644,7 @@ pub fn rest_patterns() {
644644

645645
match tuple {
646646
(first, .., last) => {
647-
let rest_start = first; // $ MISSING: type=rest_start:i32
647+
let rest_start = first; // $ type=rest_start:i32
648648
let rest_end = last; // $ MISSING: type=rest_end:u8
649649
println!("First and last: {}, {}", rest_start, rest_end);
650650
}
@@ -719,9 +719,9 @@ pub fn patterns_in_let_statements() {
719719

720720
let tuple = (1i32, 2i64, 3.0f32);
721721
let (a, b, c) = tuple; // TuplePat in let
722-
let let_a = a; // $ MISSING: type=let_a:i32
723-
let let_b = b; // $ MISSING: type=let_b:i64
724-
let let_c = c; // $ MISSING: type=let_c:f32
722+
let let_a = a; // $ type=let_a:i32
723+
let let_b = b; // $ type=let_b:i64
724+
let let_c = c; // $ type=let_c:f32
725725

726726
let array = [1i32, 2, 3, 4, 5];
727727
let [first, .., last] = array; // SlicePat in let
@@ -759,8 +759,8 @@ pub fn patterns_in_function_parameters() {
759759
}
760760

761761
fn extract_tuple((first, _, third): (i32, f64, bool)) -> (i32, bool) {
762-
let param_first = first; // $ MISSING: type=param_first:i32
763-
let param_third = third; // $ MISSING: type=param_third:bool
762+
let param_first = first; // $ type=param_first:i32
763+
let param_third = third; // $ type=param_third:bool
764764
(param_first, param_third)
765765
}
766766

@@ -772,7 +772,7 @@ pub fn patterns_in_function_parameters() {
772772
let red = extract_color(color); // $ target=extract_color type=red:u8
773773

774774
let tuple = (42i32, 3.14f64, true);
775-
let tuple_extracted = extract_tuple(tuple); // $ target=extract_tuple MISSING: type=tuple_extracted:?
775+
let tuple_extracted = extract_tuple(tuple); // $ target=extract_tuple type=tuple_extracted:0.i32 type=tuple_extracted:1.bool
776776
}
777777

778778
#[rustfmt::skip]

0 commit comments

Comments
 (0)