Skip to content

Commit a508089

Browse files
committed
Rust: Improvements to tuple type inference based on PR feedback
1 parent 8858f21 commit a508089

File tree

8 files changed

+189
-128
lines changed

8 files changed

+189
-128
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/lib/codeql/rust/elements/internal/TuplePatImpl.qll

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `TuplePat`.
43
*
@@ -12,12 +11,25 @@ private import codeql.rust.elements.internal.generated.TuplePat
1211
* be referenced directly.
1312
*/
1413
module Impl {
14+
private import rust
15+
16+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1517
/**
1618
* A tuple pattern. For example:
1719
* ```rust
1820
* let (x, y) = (1, 2);
1921
* let (a, b, .., z) = (1, 2, 3, 4, 5);
2022
* ```
2123
*/
22-
class TuplePat extends Generated::TuplePat { }
24+
class TuplePat extends Generated::TuplePat {
25+
/**
26+
* Gets the arity of the tuple matched by this pattern, if any.
27+
*
28+
* This is the number of fields in the tuple pattern if and only if the
29+
* pattern does not contain a `..` pattern.
30+
*/
31+
int getTupleArity() {
32+
result = this.getNumberOfFields() and not this.getAField() instanceof RestPat
33+
}
34+
}
2335
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@ private import codeql.rust.elements.internal.generated.Synth
1010
cached
1111
newtype TType =
1212
TTuple(int arity) {
13-
arity = any(TupleTypeRepr t).getNumberOfFields() and
13+
arity =
14+
[
15+
any(TupleTypeRepr t).getNumberOfFields(),
16+
any(TupleExpr e).getNumberOfFields(),
17+
any(TuplePat p).getNumberOfFields()
18+
] and
1419
Stages::TypeInferenceStage::ref()
1520
} or
1621
TStruct(Struct s) or
@@ -59,40 +64,33 @@ abstract class Type extends TType {
5964
abstract Location getLocation();
6065
}
6166

62-
/** The unit type `()`. */
63-
class UnitType extends Type, TTuple {
64-
UnitType() { this = TTuple(0) }
65-
66-
override StructField getStructField(string name) { none() }
67-
68-
override TupleField getTupleField(int i) { none() }
69-
70-
override TypeParameter getTypeParameter(int i) { none() }
71-
72-
override string toString() { result = "()" }
73-
74-
override Location getLocation() { result instanceof EmptyLocation }
75-
}
76-
7767
/** A tuple type `(T, ...)`. */
7868
class TupleType extends Type, TTuple {
7969
private int arity;
8070

81-
TupleType() { this = TTuple(arity) and arity > 0 }
71+
TupleType() { this = TTuple(arity) }
8272

8373
override StructField getStructField(string name) { none() }
8474

8575
override TupleField getTupleField(int i) { none() }
8676

8777
override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(arity, i) }
8878

79+
/** Gets the arity of this tuple type. */
8980
int getArity() { result = arity }
9081

9182
override string toString() { result = "(T_" + arity + ")" }
9283

9384
override Location getLocation() { result instanceof EmptyLocation }
9485
}
9586

87+
/** The unit type `()`. */
88+
class UnitType extends TupleType, TTuple {
89+
UnitType() { this = TTuple(0) }
90+
91+
override string toString() { result = "()" }
92+
}
93+
9694
abstract private class StructOrEnumType extends Type {
9795
abstract ItemNode asItemNode();
9896
}
@@ -355,8 +353,9 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
355353
/**
356354
* A tuple type parameter. For instance the `T` in `(T, U)`.
357355
*
358-
* Since tuples are structural their parameters can be represented simply as
359-
* their positional index.
356+
* Since tuples are structural their type parameters can be represented as their
357+
* positional index. The type inference library requires that type parameters
358+
* belong to a single type, so we also include the arity of the tuple type.
360359
*/
361360
class TupleTypeParameter extends TypeParameter, TTupleTypeParameter {
362361
private int arity;
@@ -371,8 +370,8 @@ class TupleTypeParameter extends TypeParameter, TTupleTypeParameter {
371370
/** Gets the index of this tuple type parameter. */
372371
int getIndex() { result = index }
373372

374-
/** Gets the arity of this tuple type parameter. */
375-
int getArity() { result = arity }
373+
/** Gets the tuple type that corresponds to this tuple type parameter. */
374+
TupleType getTupleType() { result = TTuple(arity) }
376375
}
377376

378377
/** An implicit array type parameter. */

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private module Input1 implements InputSig1<Location> {
108108
maxArity = max(int i | i = any(TupleType tt).getArity()) and
109109
tp0 = ttp and
110110
kind = 2 and
111-
id = ttp.getArity() * maxArity + ttp.getIndex()
111+
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
112112
)
113113
|
114114
tp0 order by kind, id
@@ -335,7 +335,7 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
335335
arity = n2.(TupleExpr).getNumberOfFields() and
336336
n1 = n2.(TupleExpr).getField(i)
337337
or
338-
arity = n2.(TuplePat).getNumberOfFields() and
338+
arity = n2.(TuplePat).getTupleArity() and
339339
n1 = n2.(TuplePat).getField(i)
340340
)
341341
or
@@ -553,9 +553,9 @@ private Type inferStructExprType(AstNode n, TypePath path) {
553553
}
554554

555555
pragma[nomagic]
556-
private Type inferTupleExprRootType(TupleExpr te) {
557-
// `typeEquality` handles the non-root case
558-
result = TTuple(te.getNumberOfFields())
556+
private Type inferTupleRootType(AstNode n) {
557+
// `typeEquality` handles the non-root cases
558+
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
559559
}
560560

561561
pragma[nomagic]
@@ -1091,16 +1091,27 @@ private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
10911091

10921092
/** Infers the type of `t` in `t.n` when `t` is a tuple. */
10931093
private Type inferTupleContainerExprType(Expr e, TypePath path) {
1094-
// NOTE: For a field expression `t.n` where `n` is a number `t` might both be
1095-
// a tuple struct or a tuple. It is only correct to let type information flow
1096-
// from `t.n` to tuple type parameters of `t` in the latter case. Hence we
1097-
// include the condition that the root type of `t` must be a tuple type.
1094+
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
1095+
// tuple as in:
1096+
// ```rust
1097+
// let t = (Default::default(), 2);
1098+
// let s: String = t.0;
1099+
// ```
1100+
// But it could also be a tuple struct as in:
1101+
// ```rust
1102+
// struct T(String, u32);
1103+
// let t = T(Default::default(), 2);
1104+
// let s: String = t.0;
1105+
// ```
1106+
// We need type information to flow from `t.n` to tuple type parameters of `t`
1107+
// in the former case but not the latter case. Hence we include the condition
1108+
// that the root type of `t` must be a tuple type.
10981109
exists(int i, TypePath path0, FieldExpr fe, int arity |
10991110
e = fe.getContainer() and
11001111
fe.getIdentifier().getText() = i.toString() and
11011112
arity = inferType(fe.getContainer()).(TupleType).getArity() and
11021113
result = inferType(fe, path0) and
1103-
path = TypePath::cons(TTupleTypeParameter(arity, i), path0) // FIXME:
1114+
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
11041115
)
11051116
}
11061117

@@ -1992,7 +2003,7 @@ private module Cached {
19922003
or
19932004
result = inferStructExprType(n, path)
19942005
or
1995-
result = inferTupleExprRootType(n) and
2006+
result = inferTupleRootType(n) and
19962007
path.isEmpty()
19972008
or
19982009
result = inferPathExprType(n, path)

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,6 +2357,13 @@ mod tuples {
23572357
let pair = (a, b); // $ type=pair:0(2).i64 type=pair:1(2).bool
23582358
let i: i64 = pair.0;
23592359
let j: bool = pair.1;
2360+
2361+
let pair = [1, 1].into(); // $ type=pair:0(2).i32 MISSING: target=into
2362+
match pair {
2363+
(0,0) => print!("unexpected"),
2364+
_ => print!("expected"),
2365+
}
2366+
let x = pair.0; // $ type=x:i32
23602367
}
23612368
}
23622369

rust/ql/test/library-tests/type-inference/pattern_matching.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ pub fn complex_nested_patterns() {
704704
}
705705
// Catch-all with identifier pattern
706706
other => {
707-
let other_complex = other; // $ MISSING: type=other_complex:?
707+
let other_complex = other; // $ type=other_complex:0(2).Point type=other_complex:1(2).MyOption
708708
println!("Other complex data: {:?}", other_complex);
709709
}
710710
}
@@ -766,7 +766,7 @@ pub fn patterns_in_function_parameters() {
766766

767767
// Call the functions to use them
768768
let point = Point { x: 5, y: 10 };
769-
let extracted = extract_point(point); // $ target=extract_point MISSING: type=extracted:?
769+
let extracted = extract_point(point); // $ target=extract_point type=extracted:0(2).i32 type=extracted:1(2).i32
770770

771771
let color = Color(200, 100, 50);
772772
let red = extract_color(color); // $ target=extract_color type=red:u8

0 commit comments

Comments
 (0)