Skip to content

Commit 8c11138

Browse files
committed
Address review comments
1 parent 5e7cd46 commit 8c11138

File tree

1 file changed

+45
-38
lines changed

1 file changed

+45
-38
lines changed

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ private predicate resolveExtendedCanonicalPath(Resolvable r, CrateOriginOption c
386386
}
387387

388388
/**
389-
* A reference contained in an object. For example a field in a struct.
389+
* A path to a value contained in an object. For example a field name of a struct.
390390
*/
391391
abstract class Content extends TContent {
392392
/** Gets a textual representation of this content. */
@@ -416,34 +416,34 @@ private class VariantCanonicalPath extends MkVariantCanonicalPath {
416416
abstract class VariantContent extends Content { }
417417

418418
/** A tuple variant. */
419-
private class TupleVariantContent extends VariantContent, TTupleVariantContent {
419+
private class VariantPositionContent extends VariantContent, TVariantPositionContent {
420420
private VariantCanonicalPath v;
421421
private int pos_;
422422

423-
TupleVariantContent() { this = TTupleVariantContent(v, pos_) }
423+
VariantPositionContent() { this = TVariantPositionContent(v, pos_) }
424424

425425
VariantCanonicalPath getVariantCanonicalPath(int pos) { result = v and pos = pos_ }
426426

427427
final override string toString() {
428428
// only print indices when the arity is > 1
429-
if exists(TTupleVariantContent(v, 1))
429+
if exists(TVariantPositionContent(v, 1))
430430
then result = v.toString() + "(" + pos_ + ")"
431431
else result = v.toString()
432432
}
433433
}
434434

435435
/** A record variant. */
436-
private class RecordVariantContent extends VariantContent, TRecordVariantContent {
436+
private class VariantFieldContent extends VariantContent, TVariantFieldContent {
437437
private VariantCanonicalPath v;
438438
private string field_;
439439

440-
RecordVariantContent() { this = TRecordVariantContent(v, field_) }
440+
VariantFieldContent() { this = TVariantFieldContent(v, field_) }
441441

442442
VariantCanonicalPath getVariantCanonicalPath(string field) { result = v and field = field_ }
443443

444444
final override string toString() {
445445
// only print field when the arity is > 1
446-
if strictcount(string f | exists(TRecordVariantContent(v, f))) > 1
446+
if strictcount(string f | exists(TVariantFieldContent(v, f))) > 1
447447
then result = v.toString() + "{" + field_ + "}"
448448
else result = v.toString()
449449
}
@@ -461,7 +461,7 @@ abstract class ContentSet extends TContentSet {
461461
abstract Content getAReadContent();
462462
}
463463

464-
private class SingletonContentSet extends ContentSet, TSingletonContentSet {
464+
final private class SingletonContentSet extends ContentSet, TSingletonContentSet {
465465
private Content c;
466466

467467
SingletonContentSet() { this = TSingletonContentSet(c) }
@@ -539,21 +539,18 @@ module RustDataFlow implements InputSig<Location> {
539539
final class ReturnKind = ReturnKindAlias;
540540

541541
pragma[nomagic]
542-
private predicate callResolveExtendedCanonicalPath(
543-
CallExprBase call, CrateOriginOption crate, string path
544-
) {
545-
exists(Resolvable r | resolveExtendedCanonicalPath(r, crate, path) |
546-
r = call.(MethodCallExpr)
547-
or
548-
r = call.(CallExpr).getExpr().(PathExpr).getPath()
549-
)
542+
private Resolvable getCallResolvable(CallExprBase call) {
543+
result = call.(MethodCallExpr)
544+
or
545+
result = call.(CallExpr).getExpr().(PathExpr).getPath()
550546
}
551547

552548
/** Gets a viable implementation of the target of the given `Call`. */
553549
DataFlowCallable viableCallable(DataFlowCall call) {
554-
exists(string path, CrateOriginOption crate |
550+
exists(Resolvable r, string path, CrateOriginOption crate |
555551
hasExtendedCanonicalPath(result.asCfgScope(), crate, path) and
556-
callResolveExtendedCanonicalPath(call.asCallBaseExprCfgNode().getExpr(), crate, path)
552+
r = getCallResolvable(call.asCallBaseExprCfgNode().getExpr()) and
553+
resolveExtendedCanonicalPath(r, crate, path)
557554
)
558555
}
559556

@@ -581,7 +578,7 @@ module RustDataFlow implements InputSig<Location> {
581578

582579
predicate forceHighPrecision(Content c) { none() }
583580

584-
final class ContentApprox = Content; // todo
581+
final class ContentApprox = Content; // TODO: Implement if needed
585582

586583
ContentApprox getContentApprox(Content c) { result = c }
587584

@@ -621,6 +618,10 @@ module RustDataFlow implements InputSig<Location> {
621618
// TODO: Remove once library types are extracted
622619
not p.hasQualifier() and
623620
v = MkVariantCanonicalPath(_, "crate::std::option::Option", p.getPart().getNameRef().getText())
621+
or
622+
// TODO: Remove once library types are extracted
623+
not p.hasQualifier() and
624+
v = MkVariantCanonicalPath(_, "crate::std::result::Result", p.getPart().getNameRef().getText())
624625
}
625626

626627
/** Holds if `p` destructs an enum variant `v`. */
@@ -642,22 +643,19 @@ module RustDataFlow implements InputSig<Location> {
642643
*/
643644
predicate readStep(Node node1, ContentSet cs, Node node2) {
644645
exists(Content c | c = cs.(SingletonContentSet).getContent() |
645-
node1.asPat() =
646-
any(TupleStructPatCfgNode pat, int pos |
647-
tupleVariantDestruction(pat.getPat(), c.(TupleVariantContent).getVariantCanonicalPath(pos)) and
648-
node2.asPat() = pat.getField(pos)
649-
|
650-
pat
651-
)
646+
exists(TupleStructPatCfgNode pat, int pos |
647+
pat = node1.asPat() and
648+
tupleVariantDestruction(pat.getPat(),
649+
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
650+
node2.asPat() = pat.getField(pos)
651+
)
652652
or
653-
node1.asPat() =
654-
any(RecordPatCfgNode pat, string field |
655-
recordVariantDestruction(pat.getPat(),
656-
c.(RecordVariantContent).getVariantCanonicalPath(field)) and
657-
node2.asPat() = pat.getFieldPat(field)
658-
|
659-
pat
660-
)
653+
exists(RecordPatCfgNode pat, string field |
654+
pat = node1.asPat() and
655+
recordVariantDestruction(pat.getPat(),
656+
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
657+
node2.asPat() = pat.getFieldPat(field)
658+
)
661659
)
662660
}
663661

@@ -683,7 +681,7 @@ module RustDataFlow implements InputSig<Location> {
683681
node2.asExpr() =
684682
any(CallExprCfgNode call, int pos |
685683
tupleVariantConstruction(call.getCallExpr(),
686-
c.(TupleVariantContent).getVariantCanonicalPath(pos)) and
684+
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
687685
node1.asExpr() = call.getArgument(pos)
688686
|
689687
call
@@ -692,7 +690,7 @@ module RustDataFlow implements InputSig<Location> {
692690
node2.asExpr() =
693691
any(RecordExprCfgNode re, string field |
694692
recordVariantConstruction(re.getRecordExpr(),
695-
c.(RecordVariantContent).getVariantCanonicalPath(field)) and
693+
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
696694
node1.asExpr() = re.getFieldExpr(field)
697695
|
698696
re
@@ -806,18 +804,27 @@ private module Cached {
806804
crate.isNone() and
807805
path = "crate::std::option::Option" and
808806
name = "Some"
807+
or
808+
// TODO: Remove once library types are extracted
809+
crate.isNone() and
810+
path = "crate::std::result::Result" and
811+
name = ["Ok", "Err"]
809812
}
810813

811814
cached
812815
newtype TContent =
813-
TTupleVariantContent(VariantCanonicalPath v, int pos) {
816+
TVariantPositionContent(VariantCanonicalPath v, int pos) {
814817
pos in [0 .. v.getVariant().getFieldList().(TupleFieldList).getNumberOfFields() - 1]
815818
or
816819
// TODO: Remove once library types are extracted
817820
v = MkVariantCanonicalPath(_, "crate::std::option::Option", "Some") and
818821
pos = 0
822+
or
823+
// TODO: Remove once library types are extracted
824+
v = MkVariantCanonicalPath(_, "crate::std::result::Result", ["Ok", "Err"]) and
825+
pos = 0
819826
} or
820-
TRecordVariantContent(VariantCanonicalPath v, string field) {
827+
TVariantFieldContent(VariantCanonicalPath v, string field) {
821828
field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText()
822829
}
823830

0 commit comments

Comments
 (0)