Skip to content

Rust: Unify type inference for tuple indexing expressions #20182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class TupleType extends Type, TTuple {
}

/** The unit type `()`. */
class UnitType extends TupleType, TTuple {
class UnitType extends TupleType {
UnitType() { this = TTuple(0) }

override string toString() { result = "()" }
Expand Down
169 changes: 99 additions & 70 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,36 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
)
}

pragma[inline]
private Type inferRootTypeDeref(AstNode n) {
result = inferType(n) and
result != TRefType()
or
// for reference types, lookup members in the type being referenced
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
}

pragma[nomagic]
private Type getFieldExprLookupType(FieldExpr fe, string name) {
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
}

pragma[nomagic]
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
exists(string name |
result = getFieldExprLookupType(fe, name) and
pos = name.toInt()
)
}

pragma[nomagic]
private TupleTypeParameter resolveTupleTypeFieldExpr(FieldExpr fe) {
exists(int arity, int i |
TTuple(arity) = getTupleFieldExprLookupType(fe, i) and
result = TTupleTypeParameter(arity, i)
)
}

/**
* A matching configuration for resolving types of field expressions
* like `x.field`.
Expand All @@ -1158,15 +1188,30 @@ private module FieldExprMatchingInput implements MatchingInputSig {
}
}

abstract class Declaration extends AstNode {
private newtype TDeclaration =
TStructFieldDecl(StructField sf) or
TTupleFieldDecl(TupleField tf) or
TTupleTypeParameterDecl(TupleTypeParameter ttp)

abstract class Declaration extends TDeclaration {
TypeParameter getTypeParameter(TypeParameterPosition ppos) { none() }

abstract Type getDeclaredType(DeclarationPosition dpos, TypePath path);

abstract string toString();

abstract Location getLocation();
}
Copy link
Preview

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The abstract Declaration class has a concrete implementation of getTypeParameter that returns none() for all subclasses. This suggests that either this method should be abstract (forcing subclasses to implement it) or it should have a more meaningful default implementation. Consider making this method abstract if different declaration types should handle type parameters differently.

Copilot uses AI. Check for mistakes.


abstract private class StructOrTupleFieldDecl extends Declaration {
abstract AstNode getAstNode();

abstract TypeRepr getTypeRepr();

Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
dpos.isSelf() and
// no case for variants as those can only be destructured using pattern matching
exists(Struct s | s.getStructField(_) = this or s.getTupleField(_) = this |
exists(Struct s | this.getAstNode() = [s.getStructField(_).(AstNode), s.getTupleField(_)] |
result = TStruct(s) and
path.isEmpty()
or
Expand All @@ -1177,14 +1222,55 @@ private module FieldExprMatchingInput implements MatchingInputSig {
dpos.isField() and
result = this.getTypeRepr().(TypeMention).resolveTypeAt(path)
}

override string toString() { result = this.getAstNode().toString() }

override Location getLocation() { result = this.getAstNode().getLocation() }
}

private class StructFieldDecl extends StructOrTupleFieldDecl, TStructFieldDecl {
private StructField sf;

StructFieldDecl() { this = TStructFieldDecl(sf) }

override AstNode getAstNode() { result = sf }

override TypeRepr getTypeRepr() { result = sf.getTypeRepr() }
}

private class StructFieldDecl extends Declaration instanceof StructField {
override TypeRepr getTypeRepr() { result = StructField.super.getTypeRepr() }
private class TupleFieldDecl extends StructOrTupleFieldDecl, TTupleFieldDecl {
private TupleField tf;

TupleFieldDecl() { this = TTupleFieldDecl(tf) }

override AstNode getAstNode() { result = tf }

override TypeRepr getTypeRepr() { result = tf.getTypeRepr() }
}

private class TupleFieldDecl extends Declaration instanceof TupleField {
override TypeRepr getTypeRepr() { result = TupleField.super.getTypeRepr() }
private class TupleTypeParameterDecl extends Declaration, TTupleTypeParameterDecl {
private TupleTypeParameter ttp;

TupleTypeParameterDecl() { this = TTupleTypeParameterDecl(ttp) }

override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
dpos.isSelf() and
(
result = ttp.getTupleType() and
path.isEmpty()
or
result = ttp and
path = TypePath::singleton(ttp)
)
or
dpos.isField() and
result = ttp and
path.isEmpty()
}

override string toString() { result = ttp.toString() }

override Location getLocation() { result = ttp.getLocation() }
}

class AccessPosition = DeclarationPosition;
Expand All @@ -1206,7 +1292,12 @@ private module FieldExprMatchingInput implements MatchingInputSig {

Declaration getTarget() {
// mutual recursion; resolving fields requires resolving types and vice versa
result = [resolveStructFieldExpr(this).(AstNode), resolveTupleFieldExpr(this)]
result =
[
TStructFieldDecl(resolveStructFieldExpr(this)).(TDeclaration),
TTupleFieldDecl(resolveTupleFieldExpr(this)),
TTupleTypeParameterDecl(resolveTupleTypeFieldExpr(this))
]
}
}

Expand Down Expand Up @@ -1266,42 +1357,6 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
)
}

pragma[nomagic]
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
exists(int i, TypePath path0 |
fe.getIdentifier().getText() = i.toString() and
result = inferType(fe.getContainer(), path0) and
path0.isCons(TTupleTypeParameter(_, i), path) and
fe.getIdentifier().getText() = i.toString()
)
}

/** Infers the type of `t` in `t.n` when `t` is a tuple. */
private Type inferTupleContainerExprType(Expr e, TypePath path) {
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
// tuple as in:
// ```rust
// let t = (Default::default(), 2);
// let s: String = t.0;
// ```
// But it could also be a tuple struct as in:
// ```rust
// struct T(String, u32);
// let t = T(Default::default(), 2);
// let s: String = t.0;
// ```
// We need type information to flow from `t.n` to tuple type parameters of `t`
// in the former case but not the latter case. Hence we include the condition
// that the root type of `t` must be a tuple type.
exists(int i, TypePath path0, FieldExpr fe, int arity |
e = fe.getContainer() and
fe.getIdentifier().getText() = i.toString() and
arity = inferType(fe.getContainer()).(TupleType).getArity() and
result = inferType(fe, path0) and
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
)
}

/** Gets the root type of the reference node `ref`. */
pragma[nomagic]
private Type inferRefNodeType(AstNode ref) {
Expand Down Expand Up @@ -2230,20 +2285,6 @@ private module Cached {
result = resolveFunctionCallTarget(call)
}

pragma[inline]
private Type inferRootTypeDeref(AstNode n) {
result = inferType(n) and
result != TRefType()
or
// for reference types, lookup members in the type being referenced
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
}

pragma[nomagic]
private Type getFieldExprLookupType(FieldExpr fe, string name) {
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
}

/**
* Gets the struct field that the field expression `fe` resolves to, if any.
*/
Expand All @@ -2252,14 +2293,6 @@ private module Cached {
exists(string name | result = getFieldExprLookupType(fe, name).getStructField(name))
}

pragma[nomagic]
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
exists(string name |
result = getFieldExprLookupType(fe, name) and
pos = name.toInt()
)
}

/**
* Gets the tuple field that the field expression `fe` resolves to, if any.
*/
Expand Down Expand Up @@ -2341,10 +2374,6 @@ private module Cached {
or
result = inferFieldExprType(n, path)
or
result = inferTupleIndexExprType(n, path)
or
result = inferTupleContainerExprType(n, path)
or
result = inferRefNodeType(n) and
path.isEmpty()
or
Expand Down
4 changes: 4 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2444,6 +2444,7 @@ mod explicit_type_args {
}

mod tuples {
#[derive(Debug, Clone, Copy)]
struct S1 {}

impl S1 {
Expand Down Expand Up @@ -2484,6 +2485,9 @@ mod tuples {
_ => print!("expected"),
}
let x = pair.0; // $ type=x:i32

let y = &S1::get_pair(); // $ target=get_pair
y.0.foo(); // $ target=foo
}
}

Expand Down
Loading