Skip to content

Commit 3fa4ea4

Browse files
committed
Rust: Improve performance of type inference
1 parent 5f524ef commit 3fa4ea4

File tree

2 files changed

+88
-35
lines changed

2 files changed

+88
-35
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,6 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
213213
path1 = path2
214214
)
215215
or
216-
n2 =
217-
any(PrefixExpr pe |
218-
pe.getOperatorName() = "*" and
219-
pe.getExpr() = n1 and
220-
path1 = TypePath::cons(TRefTypeParameter(), path2)
221-
)
222-
or
223216
n1 = n2.(ParenExpr).getExpr() and
224217
path1 = path2
225218
or
@@ -239,12 +232,36 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
239232
)
240233
}
241234

235+
bindingset[path1]
236+
private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
237+
typeEquality(n1, path1, n2, path2)
238+
or
239+
n2 =
240+
any(PrefixExpr pe |
241+
pe.getOperatorName() = "*" and
242+
pe.getExpr() = n1 and
243+
path1 = TypePath::consInverse(TRefTypeParameter(), path2)
244+
)
245+
}
246+
247+
bindingset[path2]
248+
private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
249+
typeEquality(n1, path1, n2, path2)
250+
or
251+
n2 =
252+
any(PrefixExpr pe |
253+
pe.getOperatorName() = "*" and
254+
pe.getExpr() = n1 and
255+
path1 = TypePath::cons(TRefTypeParameter(), path2)
256+
)
257+
}
258+
242259
pragma[nomagic]
243260
private Type inferTypeEquality(AstNode n, TypePath path) {
244261
exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) |
245-
typeEquality(n, path, n2, path2)
262+
typeEqualityRight(n, path, n2, path2)
246263
or
247-
typeEquality(n2, path2, n, path)
264+
typeEqualityLeft(n2, path2, n, path)
248265
)
249266
}
250267

@@ -909,7 +926,7 @@ private Type inferRefExprType(Expr e, TypePath path) {
909926
e = re.getExpr() and
910927
exists(TypePath exprPath, TypePath refPath, Type exprType |
911928
result = inferType(re, exprPath) and
912-
exprPath = TypePath::cons(TRefTypeParameter(), refPath) and
929+
exprPath = TypePath::consInverse(TRefTypeParameter(), refPath) and
913930
exprType = inferType(e)
914931
|
915932
if exprType = TRefType()
@@ -924,7 +941,7 @@ private Type inferRefExprType(Expr e, TypePath path) {
924941
pragma[nomagic]
925942
private Type inferTryExprType(TryExpr te, TypePath path) {
926943
exists(TypeParam tp |
927-
result = inferType(te.getExpr(), TypePath::cons(TTypeParamTypeParameter(tp), path))
944+
result = inferType(te.getExpr(), TypePath::consInverse(TTypeParamTypeParameter(tp), path))
928945
|
929946
tp = any(ResultEnum r).getGenericParamList().getGenericParam(0)
930947
or
@@ -1000,7 +1017,7 @@ private module Cached {
10001017
pragma[nomagic]
10011018
Type getTypeAt(TypePath path) {
10021019
exists(TypePath path0 | result = inferType(this, path0) |
1003-
path0 = TypePath::cons(TRefTypeParameter(), path)
1020+
path0 = TypePath::consInverse(TRefTypeParameter(), path)
10041021
or
10051022
not path0.isCons(TRefTypeParameter(), _) and
10061023
not (path0.isEmpty() and result = TRefType()) and

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,29 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
181181
/** Holds if this type path is empty. */
182182
predicate isEmpty() { this = "" }
183183

184+
/** Gets the length of this path, assuming the length is at least 2. */
185+
bindingset[this]
186+
pragma[inline_late]
187+
private int length2() {
188+
// Same as
189+
// `result = strictcount(this.indexOf(".")) + 1`
190+
// but performs better because it doesn't use an aggregate
191+
result = this.regexpReplaceAll("[0-9]+", "").length() + 1
192+
}
193+
184194
/** Gets the length of this path. */
185195
bindingset[this]
186196
pragma[inline_late]
187197
int length() {
188-
this.isEmpty() and result = 0
189-
or
190-
result = strictcount(this.indexOf(".")) + 1
198+
if this.isEmpty()
199+
then result = 0
200+
else
201+
if exists(TypeParameter::decode(this))
202+
then result = 1
203+
else result = this.length2()
191204
}
192205

193206
/** Gets the path obtained by appending `suffix` onto this path. */
194-
bindingset[suffix, result]
195-
bindingset[this, result]
196207
bindingset[this, suffix]
197208
TypePath append(TypePath suffix) {
198209
if this.isEmpty()
@@ -202,22 +213,37 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
202213
then result = this
203214
else (
204215
result = this + "." + suffix and
205-
not result.length() > getTypePathLimit()
216+
(
217+
not exists(getTypePathLimit())
218+
or
219+
result.length2() <= getTypePathLimit()
220+
)
221+
)
222+
}
223+
224+
/**
225+
* Gets the path obtained by appending `suffix` onto this path.
226+
*
227+
* Unlike `append`, this predicate has `result` in the binding set,
228+
* so there is no need to check the length of `result`.
229+
*/
230+
bindingset[this, result]
231+
TypePath appendInverse(TypePath suffix) {
232+
if result.isEmpty()
233+
then this.isEmpty() and suffix.isEmpty()
234+
else
235+
if this.isEmpty()
236+
then suffix = result
237+
else (
238+
result = this and suffix.isEmpty()
239+
or
240+
result = this + "." + suffix
206241
)
207242
}
208243

209244
/** Holds if this path starts with `tp`, followed by `suffix`. */
210245
bindingset[this]
211-
predicate isCons(TypeParameter tp, TypePath suffix) {
212-
tp = TypeParameter::decode(this) and
213-
suffix.isEmpty()
214-
or
215-
exists(int first |
216-
first = min(this.indexOf(".")) and
217-
suffix = this.suffix(first + 1) and
218-
tp = TypeParameter::decode(this.prefix(first))
219-
)
220-
}
246+
predicate isCons(TypeParameter tp, TypePath suffix) { this = TypePath::consInverse(tp, suffix) }
221247
}
222248

223249
/** Provides predicates for constructing `TypePath`s. */
@@ -232,9 +258,17 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
232258
* Gets the type path obtained by appending the singleton type path `tp`
233259
* onto `suffix`.
234260
*/
235-
bindingset[result]
236261
bindingset[suffix]
237262
TypePath cons(TypeParameter tp, TypePath suffix) { result = singleton(tp).append(suffix) }
263+
264+
/**
265+
* Gets the type path obtained by appending the singleton type path `tp`
266+
* onto `suffix`.
267+
*/
268+
bindingset[result]
269+
TypePath consInverse(TypeParameter tp, TypePath suffix) {
270+
result = singleton(tp).appendInverse(suffix)
271+
}
238272
}
239273

240274
/**
@@ -556,7 +590,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
556590
TypeMention tm1, TypeMention tm2, TypeParameter tp, TypePath path, Type t
557591
) {
558592
exists(TypePath prefix |
559-
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.append(path))
593+
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.appendInverse(path))
560594
)
561595
}
562596

@@ -899,7 +933,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
899933
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
900934
tp = target.getDeclaredType(dpos, pathToTypeParam) and
901935
accessDeclarationPositionMatch(apos, dpos) and
902-
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t)
936+
adjustedAccessType(a, apos, target, pathToTypeParam.appendInverse(path), t)
903937
)
904938
}
905939

@@ -998,7 +1032,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
9981032

9991033
RelevantAccess() { this = MkRelevantAccess(a, apos, path) }
10001034

1001-
Type getTypeAt(TypePath suffix) { a.getInferredType(apos, path.append(suffix)) = result }
1035+
Type getTypeAt(TypePath suffix) {
1036+
a.getInferredType(apos, path.appendInverse(suffix)) = result
1037+
}
10021038

10031039
/** Holds if this relevant access has the type `type` and should satisfy `constraint`. */
10041040
predicate hasTypeConstraint(Type type, Type constraint) {
@@ -1077,7 +1113,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
10771113
t0 = abs.getATypeParameter() and
10781114
exists(TypePath path3, TypePath suffix |
10791115
sub.resolveTypeAt(path3) = t0 and
1080-
at.getTypeAt(path3.append(suffix)) = t and
1116+
at.getTypeAt(path3.appendInverse(suffix)) = t and
10811117
path = prefix0.append(suffix)
10821118
)
10831119
)
@@ -1149,7 +1185,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
11491185
not exists(getTypeArgument(a, target, tp, _)) and
11501186
target = a.getTarget() and
11511187
exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam |
1152-
accessBaseType(a, apos, base, pathToTypeParam.append(path), t) and
1188+
accessBaseType(a, apos, base, pathToTypeParam.appendInverse(path), t) and
11531189
declarationBaseType(target, dpos, base, pathToTypeParam, tp) and
11541190
accessDeclarationPositionMatch(apos, dpos)
11551191
)
@@ -1217,7 +1253,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
12171253
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _, constraint, pathToTp,
12181254
tp) and
12191255
AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, constraint,
1220-
pathToTp.append(path), t)
1256+
pathToTp.appendInverse(path), t)
12211257
)
12221258
}
12231259

0 commit comments

Comments
 (0)