Skip to content

Commit 06c8963

Browse files
committed
Shared: Infer types for type parameters with contraints
1 parent 831413b commit 06c8963

File tree

3 files changed

+85
-26
lines changed

3 files changed

+85
-26
lines changed

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ mod function_trait_bounds {
307307
let x2 = MyThing { a: S1 };
308308
let y2 = MyThing { a: S2 };
309309

310-
println!("{:?}", call_trait_m1(x2)); // missing
311-
println!("{:?}", call_trait_m1(y2)); // missing
310+
println!("{:?}", call_trait_m1(x2));
311+
println!("{:?}", call_trait_m1(y2));
312312

313313
let x3 = MyThing {
314314
a: MyThing { a: S1 },
@@ -317,8 +317,8 @@ mod function_trait_bounds {
317317
a: MyThing { a: S2 },
318318
};
319319

320-
println!("{:?}", call_trait_thing_m1(x3)); // missing
321-
println!("{:?}", call_trait_thing_m1(y3)); // missing
320+
println!("{:?}", call_trait_thing_m1(x3));
321+
println!("{:?}", call_trait_thing_m1(y3));
322322
}
323323
}
324324

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,10 @@ inferType
409409
| main.rs:308:18:308:34 | MyThing {...} | | main.rs:257:5:260:5 | struct MyThing |
410410
| main.rs:308:18:308:34 | MyThing {...} | T | main.rs:264:5:265:14 | struct S2 |
411411
| main.rs:308:31:308:32 | S2 | | main.rs:264:5:265:14 | struct S2 |
412+
| main.rs:310:26:310:42 | call_trait_m1(...) | | main.rs:262:5:263:14 | struct S1 |
412413
| main.rs:310:40:310:41 | x2 | | main.rs:257:5:260:5 | struct MyThing |
413414
| main.rs:310:40:310:41 | x2 | T | main.rs:262:5:263:14 | struct S1 |
415+
| main.rs:311:26:311:42 | call_trait_m1(...) | | main.rs:264:5:265:14 | struct S2 |
414416
| main.rs:311:40:311:41 | y2 | | main.rs:257:5:260:5 | struct MyThing |
415417
| main.rs:311:40:311:41 | y2 | T | main.rs:264:5:265:14 | struct S2 |
416418
| main.rs:313:13:313:14 | x3 | | main.rs:257:5:260:5 | struct MyThing |
@@ -431,9 +433,11 @@ inferType
431433
| main.rs:317:16:317:32 | MyThing {...} | | main.rs:257:5:260:5 | struct MyThing |
432434
| main.rs:317:16:317:32 | MyThing {...} | T | main.rs:264:5:265:14 | struct S2 |
433435
| main.rs:317:29:317:30 | S2 | | main.rs:264:5:265:14 | struct S2 |
436+
| main.rs:320:26:320:48 | call_trait_thing_m1(...) | | main.rs:262:5:263:14 | struct S1 |
434437
| main.rs:320:46:320:47 | x3 | | main.rs:257:5:260:5 | struct MyThing |
435438
| main.rs:320:46:320:47 | x3 | T | main.rs:257:5:260:5 | struct MyThing |
436439
| main.rs:320:46:320:47 | x3 | T.T | main.rs:262:5:263:14 | struct S1 |
440+
| main.rs:321:26:321:48 | call_trait_thing_m1(...) | | main.rs:264:5:265:14 | struct S2 |
437441
| main.rs:321:46:321:47 | y3 | | main.rs:257:5:260:5 | struct MyThing |
438442
| main.rs:321:46:321:47 | y3 | T | main.rs:257:5:260:5 | struct MyThing |
439443
| main.rs:321:46:321:47 | y3 | T.T | main.rs:264:5:265:14 | struct S2 |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -512,36 +512,35 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
512512

513513
private module AccessBaseType {
514514
/**
515-
* Holds if inferring types at `a` might depend on the type at `apos`
516-
* having `baseMention` as a transitive base type mention.
515+
* Holds if inferring types at `a` might depend on the type at `path` of
516+
* `apos` having `baseMention` as a transitive base type mention.
517517
*/
518-
private predicate relevantAccess(Access a, AccessPosition apos, Type base) {
518+
private predicate relevantAccess(Access a, AccessPosition apos, TypePath path, Type base) {
519519
exists(Declaration target, DeclarationPosition dpos |
520520
adjustedAccessType(a, apos, target, _, _) and
521-
accessDeclarationPositionMatch(apos, dpos) and
522-
declarationBaseType(target, dpos, base, _, _)
521+
accessDeclarationPositionMatch(apos, dpos)
522+
|
523+
path.isEmpty() and declarationBaseType(target, dpos, base, _, _)
524+
or
525+
typeParameterConstraintHasTypeParameter(target, dpos, path, _, base, _, _)
523526
)
524527
}
525528

526529
pragma[nomagic]
527-
private Type inferRootType(Access a, AccessPosition apos) {
528-
relevantAccess(a, apos, _) and
529-
result = a.getInferredType(apos, TypePath::nil())
530-
}
531-
532-
pragma[nomagic]
533-
private Type inferTypeAt(Access a, AccessPosition apos, TypeParameter tp, TypePath suffix) {
534-
relevantAccess(a, apos, _) and
530+
private Type inferTypeAt(
531+
Access a, AccessPosition apos, TypePath prefix, TypeParameter tp, TypePath suffix
532+
) {
533+
relevantAccess(a, apos, prefix, _) and
535534
exists(TypePath path0 |
536-
result = a.getInferredType(apos, path0) and
535+
result = a.getInferredType(apos, prefix.append(path0)) and
537536
path0.isCons(tp, suffix)
538537
)
539538
}
540539

541540
/**
542-
* Holds if `baseMention` is a (transitive) base type mention of the type of
543-
* `a` at position `apos`, and `t` is mentioned (implicitly) at `path` inside
544-
* `base`. For example, in
541+
* Holds if `baseMention` is a (transitive) base type mention of the
542+
* type of `a` at position `apos` at path `pathToSub`, and `t` is
543+
* mentioned (implicitly) at `path` inside `base`. For example, in
545544
*
546545
* ```csharp
547546
* class C<T1> { }
@@ -570,17 +569,18 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
570569
*/
571570
pragma[nomagic]
572571
predicate hasBaseTypeMention(
573-
Access a, AccessPosition apos, TypeMention baseMention, TypePath path, Type t
572+
Access a, AccessPosition apos, TypePath pathToSub, TypeMention baseMention, TypePath path,
573+
Type t
574574
) {
575-
relevantAccess(a, apos, resolveTypeMentionRoot(baseMention)) and
576-
exists(Type sub | sub = inferRootType(a, apos) |
575+
relevantAccess(a, apos, pathToSub, resolveTypeMentionRoot(baseMention)) and
576+
exists(Type sub | sub = a.getInferredType(apos, pathToSub) |
577577
not t = sub.getATypeParameter() and
578578
baseTypeMentionHasTypeAt(sub, baseMention, path, t)
579579
or
580580
exists(TypePath prefix, TypePath suffix, TypeParameter tp |
581581
tp = sub.getATypeParameter() and
582582
baseTypeMentionHasTypeAt(sub, baseMention, prefix, tp) and
583-
t = inferTypeAt(a, apos, tp, suffix) and
583+
t = inferTypeAt(a, apos, pathToSub, tp, suffix) and
584584
path = prefix.append(suffix)
585585
)
586586
)
@@ -596,7 +596,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
596596
Access a, AccessPosition apos, Type base, TypePath path, Type t
597597
) {
598598
exists(TypeMention tm |
599-
AccessBaseType::hasBaseTypeMention(a, apos, tm, path, t) and
599+
AccessBaseType::hasBaseTypeMention(a, apos, TypePath::nil(), tm, path, t) and
600600
base = resolveTypeMentionRoot(tm)
601601
)
602602
}
@@ -671,6 +671,58 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
671671
t = getTypeArgument(a, target, tp, path)
672672
}
673673

674+
/**
675+
* Holds if `tp1` and `tp2` are distinct type parameters of `target`, the
676+
* declared type at `apos` mentions `tp1` at `path1`, `tp1` has a base
677+
* type mention of type `constrant` that mentions `tp2` at the path
678+
* `path2`.
679+
*
680+
* For this example
681+
* ```csharp
682+
* interface IFoo<A> { }
683+
* void M<T1, T2>(T2 item) where T2 : IFoo<T1> { }
684+
* ```
685+
* with the method declaration being the target and the for the first
686+
* parameter position, we have the following
687+
* - `path1 = ""`,
688+
* - `tp1 = T2`,
689+
* - `constraint = IFoo`,
690+
* - `path2 = "A"`,
691+
* - `tp2 = T1`
692+
*/
693+
pragma[nomagic]
694+
private predicate typeParameterConstraintHasTypeParameter(
695+
Declaration target, DeclarationPosition dpos, TypePath path1, TypeParameter tp1,
696+
Type constraint, TypePath path2, TypeParameter tp2
697+
) {
698+
tp1 = target.getTypeParameter(_) and
699+
tp2 = target.getTypeParameter(_) and
700+
tp1 != tp2 and
701+
tp1 = target.getDeclaredType(dpos, path1) and
702+
exists(TypeMention tm |
703+
tm = getABaseTypeMention(tp1) and
704+
tm.resolveTypeAt(path2) = tp2 and
705+
constraint = resolveTypeMentionRoot(tm)
706+
)
707+
}
708+
709+
pragma[nomagic]
710+
private predicate typeConstraintBaseTypeMatch(
711+
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
712+
) {
713+
not exists(getTypeArgument(a, target, tp, _)) and
714+
target = a.getTarget() and
715+
exists(
716+
TypeMention base, AccessPosition apos, DeclarationPosition dpos, TypePath pathToTp,
717+
TypePath pathToTp2
718+
|
719+
accessDeclarationPositionMatch(apos, dpos) and
720+
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _,
721+
resolveTypeMentionRoot(base), pathToTp, tp) and
722+
AccessBaseType::hasBaseTypeMention(a, apos, pathToTp2, base, pathToTp.append(path), t)
723+
)
724+
}
725+
674726
pragma[inline]
675727
private predicate typeMatch(
676728
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
@@ -684,6 +736,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
684736
or
685737
// We can infer the type of `tp` by going up the type hiearchy
686738
baseTypeMatch(a, target, path, t, tp)
739+
or
740+
// We can infer the type of `tp` by a type bound
741+
typeConstraintBaseTypeMatch(a, target, path, t, tp)
687742
}
688743

689744
/**

0 commit comments

Comments
 (0)