Skip to content

Commit 8e2beb7

Browse files
authored
Merge pull request github#18131 from paldepind/rust-field-flow
Rust: Data flow through tuple and struct fields
2 parents 8375c49 + e1c65aa commit 8e2beb7

File tree

4 files changed

+709
-353
lines changed

4 files changed

+709
-353
lines changed

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 146 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ module Node {
252252
* Nodes corresponding to AST elements, for example `ExprNode`, usually refer
253253
* to the value before the update.
254254
*/
255-
final class PostUpdateNode extends Node, TArgumentPostUpdateNode {
255+
final class PostUpdateNode extends Node, TExprPostUpdateNode {
256256
private ExprCfgNode n;
257257

258-
PostUpdateNode() { this = TArgumentPostUpdateNode(n) }
258+
PostUpdateNode() { this = TExprPostUpdateNode(n) }
259259

260260
/** Gets the node before the state update. */
261261
Node getPreUpdateNode() { result = TExprNode(n) }
@@ -264,7 +264,7 @@ module Node {
264264

265265
final override Location getLocation() { result = n.getLocation() }
266266

267-
final override string toString() { result = n.toString() }
267+
final override string toString() { result = "[post] " + n.toString() }
268268
}
269269

270270
final class CastNode = NaNode;
@@ -286,6 +286,9 @@ module SsaFlow {
286286
or
287287
result.(SsaFlow::ExprNode).getExpr() = n.asExpr()
288288
or
289+
result.(SsaFlow::ExprPostUpdateNode).getExpr() =
290+
n.(Node::PostUpdateNode).getPreUpdateNode().asExpr()
291+
or
289292
n = toParameterNode(result.(SsaFlow::ParameterNode).getParameter())
290293
}
291294

@@ -449,6 +452,54 @@ private class VariantFieldContent extends VariantContent, TVariantFieldContent {
449452
}
450453
}
451454

455+
/** A canonical path pointing to a struct. */
456+
private class StructCanonicalPath extends MkStructCanonicalPath {
457+
CrateOriginOption crate;
458+
string path;
459+
460+
StructCanonicalPath() { this = MkStructCanonicalPath(crate, path) }
461+
462+
/** Gets the underlying struct. */
463+
Struct getStruct() { hasExtendedCanonicalPath(result, crate, path) }
464+
465+
string toString() { result = this.getStruct().getName().getText() }
466+
467+
Location getLocation() { result = this.getStruct().getLocation() }
468+
}
469+
470+
/** Content stored in a field on a struct. */
471+
private class StructFieldContent extends Content, TStructFieldContent {
472+
private StructCanonicalPath s;
473+
private string field_;
474+
475+
StructFieldContent() { this = TStructFieldContent(s, field_) }
476+
477+
StructCanonicalPath getStructCanonicalPath(string field) { result = s and field = field_ }
478+
479+
override string toString() { result = s.toString() + "." + field_.toString() }
480+
}
481+
482+
/**
483+
* Content stored at a position in a tuple.
484+
*
485+
* NOTE: Unlike `struct`s and `enum`s, tuples are structural and not nominal,
486+
* hence we don't store a canonical path for them.
487+
*/
488+
private class TuplePositionContent extends Content, TTuplePositionContent {
489+
private int pos;
490+
491+
TuplePositionContent() { this = TTuplePositionContent(pos) }
492+
493+
int getPosition() { result = pos }
494+
495+
override string toString() { result = "tuple." + pos.toString() }
496+
}
497+
498+
/** Holds if `access` indexes a tuple at an index corresponding to `c`. */
499+
private predicate fieldTuplePositionContent(FieldExprCfgNode access, TuplePositionContent c) {
500+
access.getNameRef().getText().toInt() = c.getPosition()
501+
}
502+
452503
/** A value that represents a set of `Content`s. */
453504
abstract class ContentSet extends TContentSet {
454505
/** Gets a textual representation of this element. */
@@ -597,6 +648,14 @@ module RustDataFlow implements InputSig<Location> {
597648
*/
598649
predicate jumpStep(Node node1, Node node2) { none() }
599650

651+
/** Holds if path `p` resolves to struct `s`. */
652+
private predicate pathResolveToStructCanonicalPath(Path p, StructCanonicalPath s) {
653+
exists(CrateOriginOption crate, string path |
654+
resolveExtendedCanonicalPath(p, crate, path) and
655+
s = MkStructCanonicalPath(crate, path)
656+
)
657+
}
658+
600659
/** Holds if path `p` resolves to variant `v`. */
601660
private predicate pathResolveToVariantCanonicalPath(Path p, VariantCanonicalPath v) {
602661
exists(CrateOriginOption crate, string path |
@@ -625,6 +684,12 @@ module RustDataFlow implements InputSig<Location> {
625684
pathResolveToVariantCanonicalPath(p.getPath(), v)
626685
}
627686

687+
/** Holds if the record pattern `p` destructures a struct whose canonical path is `s`. */
688+
pragma[nomagic]
689+
private predicate structDestruction(RecordPat p, StructCanonicalPath s) {
690+
pathResolveToStructCanonicalPath(p.getPath(), s)
691+
}
692+
628693
/**
629694
* Holds if data can flow from `node1` to `node2` via a read of `c`. Thus,
630695
* `node1` references an object with a content `c.getAReadContent()` whose
@@ -641,10 +706,24 @@ module RustDataFlow implements InputSig<Location> {
641706
or
642707
exists(RecordPatCfgNode pat, string field |
643708
pat = node1.asPat() and
644-
recordVariantDestruction(pat.getPat(),
645-
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
709+
(
710+
// Pattern destructures a struct-like variant.
711+
recordVariantDestruction(pat.getPat(),
712+
c.(VariantFieldContent).getVariantCanonicalPath(field))
713+
or
714+
// Pattern destructures a struct.
715+
structDestruction(pat.getPat(), c.(StructFieldContent).getStructCanonicalPath(field))
716+
) and
646717
node2.asPat() = pat.getFieldPat(field)
647718
)
719+
or
720+
exists(FieldExprCfgNode access |
721+
// Read of a tuple entry
722+
fieldTuplePositionContent(access, c) and
723+
// TODO: Handle read of a struct field.
724+
node1.asExpr() = access.getExpr() and
725+
node2.asExpr() = access
726+
)
648727
)
649728
}
650729

@@ -660,30 +739,55 @@ module RustDataFlow implements InputSig<Location> {
660739
pathResolveToVariantCanonicalPath(re.getPath(), v)
661740
}
662741

742+
/** Holds if `re` constructs a struct value of type `s`. */
743+
pragma[nomagic]
744+
private predicate structConstruction(RecordExpr re, StructCanonicalPath s) {
745+
pathResolveToStructCanonicalPath(re.getPath(), s)
746+
}
747+
748+
private predicate tupleAssignment(Node node1, Node node2, TuplePositionContent c) {
749+
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
750+
assignment.getLhs() = access and
751+
fieldTuplePositionContent(access, c) and
752+
node1.asExpr() = assignment.getRhs() and
753+
node2.asExpr() = access.getExpr()
754+
)
755+
}
756+
663757
/**
664758
* Holds if data can flow from `node1` to `node2` via a store into `c`. Thus,
665759
* `node2` references an object with a content `c.getAStoreContent()` that
666760
* contains the value of `node1`.
667761
*/
668762
predicate storeStep(Node node1, ContentSet cs, Node node2) {
669763
exists(Content c | c = cs.(SingletonContentSet).getContent() |
670-
node2.asExpr() =
671-
any(CallExprCfgNode call, int pos |
672-
tupleVariantConstruction(call.getCallExpr(),
673-
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
674-
node1.asExpr() = call.getArgument(pos)
675-
|
676-
call
677-
)
764+
exists(CallExprCfgNode call, int pos |
765+
tupleVariantConstruction(call.getCallExpr(),
766+
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
767+
node1.asExpr() = call.getArgument(pos) and
768+
node2.asExpr() = call
769+
)
678770
or
679-
node2.asExpr() =
680-
any(RecordExprCfgNode re, string field |
771+
exists(RecordExprCfgNode re, string field |
772+
(
773+
// Expression is for a struct-like enum variant.
681774
recordVariantConstruction(re.getRecordExpr(),
682-
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
683-
node1.asExpr() = re.getFieldExpr(field)
684-
|
685-
re
686-
)
775+
c.(VariantFieldContent).getVariantCanonicalPath(field))
776+
or
777+
// Expression is for a struct.
778+
structConstruction(re.getRecordExpr(),
779+
c.(StructFieldContent).getStructCanonicalPath(field))
780+
) and
781+
node1.asExpr() = re.getFieldExpr(field) and
782+
node2.asExpr() = re
783+
)
784+
or
785+
exists(TupleExprCfgNode tuple |
786+
node1.asExpr() = tuple.getField(c.(TuplePositionContent).getPosition()) and
787+
node2.asExpr() = tuple
788+
)
789+
or
790+
tupleAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
687791
)
688792
}
689793

@@ -692,7 +796,9 @@ module RustDataFlow implements InputSig<Location> {
692796
* any value stored inside `f` is cleared at the pre-update node associated with `x`
693797
* in `x.f = newValue`.
694798
*/
695-
predicate clearsContent(Node n, ContentSet c) { none() }
799+
predicate clearsContent(Node n, ContentSet cs) {
800+
tupleAssignment(_, n, cs.(SingletonContentSet).getContent())
801+
}
696802

697803
/**
698804
* Holds if the value that is being tracked is expected to be stored inside content `c`
@@ -762,7 +868,9 @@ private module Cached {
762868
TExprNode(ExprCfgNode n) or
763869
TParameterNode(ParamBaseCfgNode p) or
764870
TPatNode(PatCfgNode p) or
765-
TArgumentPostUpdateNode(ExprCfgNode e) { isArgumentForCall(e, _, _) } or
871+
TExprPostUpdateNode(ExprCfgNode e) {
872+
isArgumentForCall(e, _, _) or e = any(FieldExprCfgNode access).getExpr()
873+
} or
766874
TSsaNode(SsaImpl::DataFlowIntegration::SsaNode node)
767875

768876
cached
@@ -800,6 +908,12 @@ private module Cached {
800908
name = ["Ok", "Err"]
801909
}
802910

911+
cached
912+
newtype TStructCanonicalPath =
913+
MkStructCanonicalPath(CrateOriginOption crate, string path) {
914+
exists(Struct s | hasExtendedCanonicalPath(s, crate, path))
915+
}
916+
803917
cached
804918
newtype TContent =
805919
TVariantPositionContent(VariantCanonicalPath v, int pos) {
@@ -815,6 +929,16 @@ private module Cached {
815929
} or
816930
TVariantFieldContent(VariantCanonicalPath v, string field) {
817931
field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText()
932+
} or
933+
TTuplePositionContent(int pos) {
934+
pos in [0 .. max([
935+
any(TuplePat pat).getNumberOfFields(),
936+
any(FieldExpr access).getNameRef().getText().toInt()
937+
]
938+
)]
939+
} or
940+
TStructFieldContent(StructCanonicalPath s, string field) {
941+
field = s.getStruct().getFieldList().(RecordFieldList).getAField().getName().getText()
818942
}
819943

820944
cached

0 commit comments

Comments
 (0)