Skip to content

Commit cb8a603

Browse files
committed
wip
1 parent 0e917ed commit cb8a603

File tree

2 files changed

+49
-109
lines changed

2 files changed

+49
-109
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,13 @@ private module MethodCallResolution {
13161316
mc.isMethodCall(name, arity)
13171317
}
13181318

1319+
pragma[nomagic]
1320+
private predicate isNotInherentTarget(Impl impl) {
1321+
IsInstantiationOf<MethodCallCand, FunctionPositionType, MethodCallIsInstantiationOfInput>::isNotInstantiationOf(this,
1322+
impl, _) and
1323+
not impl.hasTrait()
1324+
}
1325+
13191326
/**
13201327
* Holds if this method call has no inherent target, i.e., it does not
13211328
* resolve to a method in an `impl` block for the type of the receiver.
@@ -1328,8 +1335,7 @@ private module MethodCallResolution {
13281335
methodCandidate(rootType, name, arity, impl, _) and
13291336
not impl.hasTrait()
13301337
|
1331-
IsInstantiationOf<MethodCallCand, FunctionPositionType, MethodCallIsInstantiationOfInput>::isNotInstantiationOf(this,
1332-
impl, _)
1338+
this.isNotInherentTarget(impl)
13331339
)
13341340
)
13351341
}
@@ -1527,12 +1533,13 @@ private module MethodCallMatchingInput implements MatchingWithStateInputSig {
15271533
)
15281534
}
15291535

1530-
bindingset[state]
1536+
pragma[nomagic]
15311537
Type getInferredType(State state, AccessPosition apos, TypePath path) {
15321538
apos.asArgumentPosition().isSelf() and
15331539
result = this.getInferredSelfType(state, path)
15341540
or
1535-
result = this.getInferredNonSelfType(apos, path)
1541+
result = this.getInferredNonSelfType(apos, path) and
1542+
exists(this.getTarget(state))
15361543
}
15371544

15381545
Declaration getTarget(State state) {
@@ -2000,10 +2007,11 @@ private module OperationResolution {
20002007
}
20012008

20022009
pragma[nomagic]
2003-
predicate isOperation(Type rootType, string name, int arity) {
2010+
predicate isOperation(Type rootType, Trait trait, string name, int arity) {
20042011
name = this.(Call).getMethodName() and
20052012
arity = this.(Call).getNumberOfArguments() and
2006-
rootType = this.getTypeAt(TypePath::nil())
2013+
rootType = this.getTypeAt(TypePath::nil()) and
2014+
trait = this.(Call).getTrait()
20072015
}
20082016

20092017
pragma[nomagic]
@@ -2055,9 +2063,13 @@ private module OperationResolution {
20552063

20562064
pragma[nomagic]
20572065
predicate potentialInstantiationOf(Op op, TypeAbstraction abs, FunctionPositionType constraint) {
2058-
exists(Type rootType, string name, int arity |
2059-
op.isOperation(rootType, name, arity) and
2060-
methodCandidateTrait(rootType, op.(Call).getTrait(), name, arity, abs, constraint)
2066+
exists(Type rootType, Trait trait, string name, int arity |
2067+
op.isOperation(rootType, trait, name, arity) and
2068+
MethodCallResolution::methodCandidate(rootType, name, arity, abs, constraint)
2069+
|
2070+
trait = abs.(ImplItemNode).resolveTraitTy()
2071+
or
2072+
trait = abs
20612073
)
20622074
}
20632075

@@ -2330,7 +2342,17 @@ private module FieldExprMatchingInput implements MatchingInputSig {
23302342
}
23312343

23322344
Type getInferredType(AccessPosition apos, TypePath path) {
2333-
result = inferType(this.getNodeAt(apos), path)
2345+
exists(TypePath path0 | result = inferType(this.getNodeAt(apos), path0) |
2346+
if apos.isSelf()
2347+
then
2348+
// adjust for implicit deref
2349+
path0.isCons(TRefTypeParameter(), path)
2350+
or
2351+
not path0.isCons(TRefTypeParameter(), _) and
2352+
not (result = TRefType() and path0.isEmpty()) and
2353+
path = path0
2354+
else path = path0
2355+
)
23342356
}
23352357

23362358
Declaration getTarget() {
@@ -2347,28 +2369,6 @@ private module FieldExprMatchingInput implements MatchingInputSig {
23472369
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
23482370
apos = dpos
23492371
}
2350-
2351-
bindingset[apos, target, path, t]
2352-
pragma[inline_late]
2353-
predicate adjustAccessType(
2354-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
2355-
) {
2356-
exists(target) and
2357-
if apos.isSelf()
2358-
then
2359-
// adjust for implicit deref
2360-
path.isCons(TRefTypeParameter(), pathAdj) and
2361-
tAdj = t
2362-
or
2363-
not path.isCons(TRefTypeParameter(), _) and
2364-
not (t = TRefType() and path.isEmpty()) and
2365-
pathAdj = path and
2366-
tAdj = t
2367-
else (
2368-
pathAdj = path and
2369-
tAdj = t
2370-
)
2371-
}
23722372
}
23732373

23742374
private module FieldExprMatching = Matching<FieldExprMatchingInput>;

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 17 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
10921092
* For example, if this access is the method call `M(42)`, then the inferred
10931093
* type at argument position `0` is `int`.
10941094
*/
1095-
bindingset[state]
10961095
Type getInferredType(State state, AccessPosition apos, TypePath path);
10971096

10981097
/** Gets the declaration that this access targets in `state`. */
@@ -1103,29 +1102,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11031102
bindingset[apos]
11041103
bindingset[dpos]
11051104
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos);
1106-
1107-
/**
1108-
* Holds if matching an inferred type `t` at `path` inside an access at `apos`
1109-
* against the declaration `target` means that the type should be adjusted to
1110-
* `tAdj` at `pathAdj`.
1111-
*
1112-
* For example, in
1113-
*
1114-
* ```csharp
1115-
* void M(int? i) {}
1116-
* M(42);
1117-
* ```
1118-
*
1119-
* the inferred type of `42` is `int`, but it should be adjusted to `int?`
1120-
* when matching against `M`.
1121-
*/
1122-
bindingset[apos, target, path, t]
1123-
default predicate adjustAccessType(
1124-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
1125-
) {
1126-
pathAdj = path and
1127-
tAdj = t
1128-
}
11291105
}
11301106

11311107
/**
@@ -1137,21 +1113,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11371113
module MatchingWithState<MatchingWithStateInputSig Input> {
11381114
private import Input
11391115

1140-
/**
1141-
* Holds if `a` targets `target` in `state` and the type for `apos` at `path`
1142-
* in `a` is `t` after adjustment by `target`.
1143-
*/
1144-
pragma[nomagic]
1145-
private predicate adjustedAccessType(
1146-
Access a, State state, AccessPosition apos, Declaration target, TypePath path, Type t
1147-
) {
1148-
target = a.getTarget(state) and
1149-
exists(TypePath path0, Type t0 |
1150-
t0 = a.getInferredType(state, apos, path0) and
1151-
adjustAccessType(apos, target, path0, t0, path, t)
1152-
)
1153-
}
1154-
11551116
/**
11561117
* Gets the type of the type argument at `path` in `a` that corresponds to
11571118
* the type parameter `tp` in `target`, if any.
@@ -1182,7 +1143,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11821143
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
11831144
tp = target.getDeclaredType(dpos, pathToTypeParam) and
11841145
accessDeclarationPositionMatch(apos, dpos) and
1185-
adjustedAccessType(a, state, apos, target, pathToTypeParam.appendInverse(path), t)
1146+
target = a.getTarget(state) and
1147+
t = a.getInferredType(state, apos, pathToTypeParam.appendInverse(path))
11861148
)
11871149
}
11881150

@@ -1193,7 +1155,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11931155
*/
11941156
private predicate relevantAccess(Access a, State state, AccessPosition apos, Type base) {
11951157
exists(Declaration target, DeclarationPosition dpos |
1196-
adjustedAccessType(a, state, apos, target, _, _) and
1158+
target = a.getTarget(state) and
11971159
accessDeclarationPositionMatch(apos, dpos) and
11981160
declarationBaseType(target, dpos, base, _, _)
11991161
)
@@ -1268,10 +1230,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
12681230
}
12691231

12701232
private newtype TRelevantAccess =
1271-
MkRelevantAccess(
1272-
Access a, State state, Declaration target, AccessPosition apos, TypePath path
1273-
) {
1274-
relevantAccessConstraint(a, state, target, apos, path, _)
1233+
MkRelevantAccess(Access a, State state, AccessPosition apos, TypePath path) {
1234+
relevantAccessConstraint(a, state, _, apos, path, _)
12751235
}
12761236

12771237
/**
@@ -1281,18 +1241,19 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
12811241
private class RelevantAccess extends MkRelevantAccess {
12821242
Access a;
12831243
State state;
1284-
Declaration target;
12851244
AccessPosition apos;
12861245
TypePath path;
12871246

1288-
RelevantAccess() { this = MkRelevantAccess(a, state, target, apos, path) }
1247+
RelevantAccess() { this = MkRelevantAccess(a, state, apos, path) }
12891248

12901249
Type getTypeAt(TypePath suffix) {
1291-
adjustedAccessType(a, state, apos, target, path.appendInverse(suffix), result)
1250+
result = a.getInferredType(state, apos, path.appendInverse(suffix))
12921251
}
12931252

12941253
/** Holds if this relevant access should satisfy `constraint`. */
1295-
Type getConstraint() { relevantAccessConstraint(a, state, target, apos, path, result) }
1254+
Type getConstraint(Declaration target) {
1255+
relevantAccessConstraint(a, state, target, apos, path, result)
1256+
}
12961257

12971258
string toString() {
12981259
result = a.toString() + ", " + apos.toString() + ", " + path.toString()
@@ -1305,16 +1266,20 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
13051266
SatisfiesConstraintInputSig<RelevantAccess>
13061267
{
13071268
predicate relevantConstraint(RelevantAccess at, Type constraint) {
1308-
constraint = at.getConstraint()
1269+
constraint = at.getConstraint(_)
13091270
}
13101271
}
13111272

13121273
predicate satisfiesConstraintType(
13131274
Access a, State state, Declaration target, AccessPosition apos, TypePath prefix,
13141275
Type constraint, TypePath path, Type t
13151276
) {
1316-
SatisfiesConstraint<RelevantAccess, SatisfiesConstraintInput>::satisfiesConstraintType(MkRelevantAccess(a,
1317-
state, target, apos, prefix), constraint, path, t)
1277+
exists(RelevantAccess ra |
1278+
ra = MkRelevantAccess(a, state, apos, prefix) and
1279+
SatisfiesConstraint<RelevantAccess, SatisfiesConstraintInput>::satisfiesConstraintType(ra,
1280+
constraint, path, t) and
1281+
constraint = ra.getConstraint(target)
1282+
)
13181283
}
13191284
}
13201285

@@ -1605,29 +1570,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
16051570
bindingset[apos]
16061571
bindingset[dpos]
16071572
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos);
1608-
1609-
/**
1610-
* Holds if matching an inferred type `t` at `path` inside an access at `apos`
1611-
* against the declaration `target` means that the type should be adjusted to
1612-
* `tAdj` at `pathAdj`.
1613-
*
1614-
* For example, in
1615-
*
1616-
* ```csharp
1617-
* void M(int? i) {}
1618-
* M(42);
1619-
* ```
1620-
*
1621-
* the inferred type of `42` is `int`, but it should be adjusted to `int?`
1622-
* when matching against `M`.
1623-
*/
1624-
bindingset[apos, target, path, t]
1625-
default predicate adjustAccessType(
1626-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
1627-
) {
1628-
pathAdj = path and
1629-
tAdj = t
1630-
}
16311573
}
16321574

16331575
/**
@@ -1641,8 +1583,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
16411583
private import codeql.util.Unit
16421584
import Input
16431585

1644-
predicate adjustAccessType = Input::adjustAccessType/6;
1645-
16461586
class State = Unit;
16471587

16481588
final private class AccessFinal = Input::Access;

0 commit comments

Comments
 (0)