Skip to content

Commit b4bd65b

Browse files
committed
wip
1 parent 0e917ed commit b4bd65b

File tree

2 files changed

+47
-96
lines changed

2 files changed

+47
-96
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,13 @@ private module MethodCallResolution {
13161316
mc.isMethodCall(name, arity)
13171317
}
13181318

1319+
pragma[nomagic]
1320+
private predicate isNotInherentTarget(Impl impl) {
1321+
IsInstantiationOf<MethodCallCand, FunctionPositionType, MethodCallIsInstantiationOfInput>::isNotInstantiationOf(this,
1322+
impl, _) and
1323+
not impl.hasTrait()
1324+
}
1325+
13191326
/**
13201327
* Holds if this method call has no inherent target, i.e., it does not
13211328
* resolve to a method in an `impl` block for the type of the receiver.
@@ -1328,8 +1335,7 @@ private module MethodCallResolution {
13281335
methodCandidate(rootType, name, arity, impl, _) and
13291336
not impl.hasTrait()
13301337
|
1331-
IsInstantiationOf<MethodCallCand, FunctionPositionType, MethodCallIsInstantiationOfInput>::isNotInstantiationOf(this,
1332-
impl, _)
1338+
this.isNotInherentTarget(impl)
13331339
)
13341340
)
13351341
}
@@ -2000,10 +2006,11 @@ private module OperationResolution {
20002006
}
20012007

20022008
pragma[nomagic]
2003-
predicate isOperation(Type rootType, string name, int arity) {
2009+
predicate isOperation(Type rootType, Trait trait, string name, int arity) {
20042010
name = this.(Call).getMethodName() and
20052011
arity = this.(Call).getNumberOfArguments() and
2006-
rootType = this.getTypeAt(TypePath::nil())
2012+
rootType = this.getTypeAt(TypePath::nil()) and
2013+
trait = this.(Call).getTrait()
20072014
}
20082015

20092016
pragma[nomagic]
@@ -2055,9 +2062,13 @@ private module OperationResolution {
20552062

20562063
pragma[nomagic]
20572064
predicate potentialInstantiationOf(Op op, TypeAbstraction abs, FunctionPositionType constraint) {
2058-
exists(Type rootType, string name, int arity |
2059-
op.isOperation(rootType, name, arity) and
2060-
methodCandidateTrait(rootType, op.(Call).getTrait(), name, arity, abs, constraint)
2065+
exists(Type rootType, Trait trait, string name, int arity |
2066+
op.isOperation(rootType, trait, name, arity) and
2067+
MethodCallResolution::methodCandidate(rootType, name, arity, abs, constraint)
2068+
|
2069+
trait = abs.(ImplItemNode).resolveTraitTy()
2070+
or
2071+
trait = abs
20612072
)
20622073
}
20632074

@@ -2330,7 +2341,17 @@ private module FieldExprMatchingInput implements MatchingInputSig {
23302341
}
23312342

23322343
Type getInferredType(AccessPosition apos, TypePath path) {
2333-
result = inferType(this.getNodeAt(apos), path)
2344+
exists(TypePath path0 | result = inferType(this.getNodeAt(apos), path0) |
2345+
if apos.isSelf()
2346+
then
2347+
// adjust for implicit deref
2348+
path0.isCons(TRefTypeParameter(), path)
2349+
or
2350+
not path0.isCons(TRefTypeParameter(), _) and
2351+
not (result = TRefType() and path0.isEmpty()) and
2352+
path = path0
2353+
else path = path0
2354+
)
23342355
}
23352356

23362357
Declaration getTarget() {
@@ -2347,28 +2368,6 @@ private module FieldExprMatchingInput implements MatchingInputSig {
23472368
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
23482369
apos = dpos
23492370
}
2350-
2351-
bindingset[apos, target, path, t]
2352-
pragma[inline_late]
2353-
predicate adjustAccessType(
2354-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
2355-
) {
2356-
exists(target) and
2357-
if apos.isSelf()
2358-
then
2359-
// adjust for implicit deref
2360-
path.isCons(TRefTypeParameter(), pathAdj) and
2361-
tAdj = t
2362-
or
2363-
not path.isCons(TRefTypeParameter(), _) and
2364-
not (t = TRefType() and path.isEmpty()) and
2365-
pathAdj = path and
2366-
tAdj = t
2367-
else (
2368-
pathAdj = path and
2369-
tAdj = t
2370-
)
2371-
}
23722371
}
23732372

23742373
private module FieldExprMatching = Matching<FieldExprMatchingInput>;

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,29 +1103,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11031103
bindingset[apos]
11041104
bindingset[dpos]
11051105
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos);
1106-
1107-
/**
1108-
* Holds if matching an inferred type `t` at `path` inside an access at `apos`
1109-
* against the declaration `target` means that the type should be adjusted to
1110-
* `tAdj` at `pathAdj`.
1111-
*
1112-
* For example, in
1113-
*
1114-
* ```csharp
1115-
* void M(int? i) {}
1116-
* M(42);
1117-
* ```
1118-
*
1119-
* the inferred type of `42` is `int`, but it should be adjusted to `int?`
1120-
* when matching against `M`.
1121-
*/
1122-
bindingset[apos, target, path, t]
1123-
default predicate adjustAccessType(
1124-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
1125-
) {
1126-
pathAdj = path and
1127-
tAdj = t
1128-
}
11291106
}
11301107

11311108
/**
@@ -1142,14 +1119,11 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11421119
* in `a` is `t` after adjustment by `target`.
11431120
*/
11441121
pragma[nomagic]
1145-
private predicate adjustedAccessType(
1122+
private predicate accessType(
11461123
Access a, State state, AccessPosition apos, Declaration target, TypePath path, Type t
11471124
) {
11481125
target = a.getTarget(state) and
1149-
exists(TypePath path0, Type t0 |
1150-
t0 = a.getInferredType(state, apos, path0) and
1151-
adjustAccessType(apos, target, path0, t0, path, t)
1152-
)
1126+
t = a.getInferredType(state, apos, path)
11531127
}
11541128

11551129
/**
@@ -1182,7 +1156,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11821156
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
11831157
tp = target.getDeclaredType(dpos, pathToTypeParam) and
11841158
accessDeclarationPositionMatch(apos, dpos) and
1185-
adjustedAccessType(a, state, apos, target, pathToTypeParam.appendInverse(path), t)
1159+
accessType(a, state, apos, target, pathToTypeParam.appendInverse(path), t)
11861160
)
11871161
}
11881162

@@ -1193,7 +1167,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11931167
*/
11941168
private predicate relevantAccess(Access a, State state, AccessPosition apos, Type base) {
11951169
exists(Declaration target, DeclarationPosition dpos |
1196-
adjustedAccessType(a, state, apos, target, _, _) and
1170+
accessType(a, state, apos, target, _, _) and
11971171
accessDeclarationPositionMatch(apos, dpos) and
11981172
declarationBaseType(target, dpos, base, _, _)
11991173
)
@@ -1268,10 +1242,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
12681242
}
12691243

12701244
private newtype TRelevantAccess =
1271-
MkRelevantAccess(
1272-
Access a, State state, Declaration target, AccessPosition apos, TypePath path
1273-
) {
1274-
relevantAccessConstraint(a, state, target, apos, path, _)
1245+
MkRelevantAccess(Access a, State state, AccessPosition apos, TypePath path) {
1246+
relevantAccessConstraint(a, state, _, apos, path, _)
12751247
}
12761248

12771249
/**
@@ -1281,18 +1253,19 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
12811253
private class RelevantAccess extends MkRelevantAccess {
12821254
Access a;
12831255
State state;
1284-
Declaration target;
12851256
AccessPosition apos;
12861257
TypePath path;
12871258

1288-
RelevantAccess() { this = MkRelevantAccess(a, state, target, apos, path) }
1259+
RelevantAccess() { this = MkRelevantAccess(a, state, apos, path) }
12891260

12901261
Type getTypeAt(TypePath suffix) {
1291-
adjustedAccessType(a, state, apos, target, path.appendInverse(suffix), result)
1262+
accessType(a, state, apos, _, path.appendInverse(suffix), result)
12921263
}
12931264

12941265
/** Holds if this relevant access should satisfy `constraint`. */
1295-
Type getConstraint() { relevantAccessConstraint(a, state, target, apos, path, result) }
1266+
Type getConstraint(Declaration target) {
1267+
relevantAccessConstraint(a, state, target, apos, path, result)
1268+
}
12961269

12971270
string toString() {
12981271
result = a.toString() + ", " + apos.toString() + ", " + path.toString()
@@ -1305,16 +1278,20 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
13051278
SatisfiesConstraintInputSig<RelevantAccess>
13061279
{
13071280
predicate relevantConstraint(RelevantAccess at, Type constraint) {
1308-
constraint = at.getConstraint()
1281+
constraint = at.getConstraint(_)
13091282
}
13101283
}
13111284

13121285
predicate satisfiesConstraintType(
13131286
Access a, State state, Declaration target, AccessPosition apos, TypePath prefix,
13141287
Type constraint, TypePath path, Type t
13151288
) {
1316-
SatisfiesConstraint<RelevantAccess, SatisfiesConstraintInput>::satisfiesConstraintType(MkRelevantAccess(a,
1317-
state, target, apos, prefix), constraint, path, t)
1289+
exists(RelevantAccess ra |
1290+
ra = MkRelevantAccess(a, state, apos, prefix) and
1291+
SatisfiesConstraint<RelevantAccess, SatisfiesConstraintInput>::satisfiesConstraintType(ra,
1292+
constraint, path, t) and
1293+
constraint = ra.getConstraint(target)
1294+
)
13181295
}
13191296
}
13201297

@@ -1605,29 +1582,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
16051582
bindingset[apos]
16061583
bindingset[dpos]
16071584
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos);
1608-
1609-
/**
1610-
* Holds if matching an inferred type `t` at `path` inside an access at `apos`
1611-
* against the declaration `target` means that the type should be adjusted to
1612-
* `tAdj` at `pathAdj`.
1613-
*
1614-
* For example, in
1615-
*
1616-
* ```csharp
1617-
* void M(int? i) {}
1618-
* M(42);
1619-
* ```
1620-
*
1621-
* the inferred type of `42` is `int`, but it should be adjusted to `int?`
1622-
* when matching against `M`.
1623-
*/
1624-
bindingset[apos, target, path, t]
1625-
default predicate adjustAccessType(
1626-
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
1627-
) {
1628-
pathAdj = path and
1629-
tAdj = t
1630-
}
16311585
}
16321586

16331587
/**
@@ -1641,8 +1595,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
16411595
private import codeql.util.Unit
16421596
import Input
16431597

1644-
predicate adjustAccessType = Input::adjustAccessType/6;
1645-
16461598
class State = Unit;
16471599

16481600
final private class AccessFinal = Input::Access;

0 commit comments

Comments
 (0)