Skip to content

Commit 597f008

Browse files
authored
Merge pull request github#15958 from MathiasVP/ir-guards-from-switch-statements-2
C++: Implement guards logic for switch statements
2 parents df18453 + d7afd7b commit 597f008

File tree

8 files changed

+411
-66
lines changed

8 files changed

+411
-66
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
category: feature
3+
---
4+
* Added a predicate `GuardCondition.comparesEq/4` to query whether an expression is compared to a constant.
5+
* Added a predicate `GuardCondition.ensuresEq/4` to query whether a basic block is guarded by an expression being equal to a constant.

cpp/ql/lib/semmle/code/cpp/controlflow/IRGuards.qll

Lines changed: 199 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@ class GuardCondition extends Expr {
137137
*/
138138
cached
139139
predicate ensuresEq(Expr left, Expr right, int k, BasicBlock block, boolean areEqual) { none() }
140+
141+
/** Holds if (determined by this guard) `e == k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
142+
cached
143+
predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) { none() }
144+
145+
/**
146+
* Holds if (determined by this guard) `e == k` must be `areEqual` in `block`.
147+
* If `areEqual = false` then this implies `e != k`.
148+
*/
149+
cached
150+
predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) { none() }
140151
}
141152

142153
/**
@@ -184,6 +195,20 @@ private class GuardConditionFromBinaryLogicalOperator extends GuardCondition {
184195
this.comparesEq(left, right, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
185196
)
186197
}
198+
199+
override predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) {
200+
exists(boolean partIsTrue, GuardCondition part |
201+
this.(BinaryLogicalOperation).impliesValue(part, partIsTrue, testIsTrue)
202+
|
203+
part.comparesEq(e, k, areEqual, partIsTrue)
204+
)
205+
}
206+
207+
override predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) {
208+
exists(boolean testIsTrue |
209+
this.comparesEq(e, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
210+
)
211+
}
187212
}
188213

189214
/**
@@ -245,6 +270,21 @@ private class GuardConditionFromIR extends GuardCondition {
245270
)
246271
}
247272

273+
override predicate comparesEq(Expr e, int k, boolean areEqual, boolean testIsTrue) {
274+
exists(Instruction i |
275+
i.getUnconvertedResultExpression() = e and
276+
ir.comparesEq(i.getAUse(), k, areEqual, testIsTrue)
277+
)
278+
}
279+
280+
override predicate ensuresEq(Expr e, int k, BasicBlock block, boolean areEqual) {
281+
exists(Instruction i, boolean testIsTrue |
282+
i.getUnconvertedResultExpression() = e and
283+
ir.comparesEq(i.getAUse(), k, areEqual, testIsTrue) and
284+
this.controls(block, testIsTrue)
285+
)
286+
}
287+
248288
/**
249289
* Holds if this condition controls `block`, meaning that `block` is only
250290
* entered if the value of this condition is `v`. This helper
@@ -446,7 +486,25 @@ class IRGuardCondition extends Instruction {
446486
/** Holds if (determined by this guard) `left == right + k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
447487
cached
448488
predicate comparesEq(Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue) {
449-
compares_eq(this, left, right, k, areEqual, testIsTrue)
489+
exists(BooleanValue value |
490+
compares_eq(this, left, right, k, areEqual, value) and
491+
value.getValue() = testIsTrue
492+
)
493+
}
494+
495+
/** Holds if (determined by this guard) `op == k` evaluates to `areEqual` if this expression evaluates to `testIsTrue`. */
496+
cached
497+
predicate comparesEq(Operand op, int k, boolean areEqual, boolean testIsTrue) {
498+
exists(MatchValue mv |
499+
compares_eq(this, op, k, areEqual, mv) and
500+
// A match value cannot be dualized, so `testIsTrue` is always true
501+
testIsTrue = true
502+
)
503+
or
504+
exists(BooleanValue bv |
505+
compares_eq(this, op, k, areEqual, bv) and
506+
bv.getValue() = testIsTrue
507+
)
450508
}
451509

452510
/**
@@ -455,8 +513,19 @@ class IRGuardCondition extends Instruction {
455513
*/
456514
cached
457515
predicate ensuresEq(Operand left, Operand right, int k, IRBlock block, boolean areEqual) {
458-
exists(boolean testIsTrue |
459-
compares_eq(this, left, right, k, areEqual, testIsTrue) and this.controls(block, testIsTrue)
516+
exists(AbstractValue value |
517+
compares_eq(this, left, right, k, areEqual, value) and this.valueControls(block, value)
518+
)
519+
}
520+
521+
/**
522+
* Holds if (determined by this guard) `op == k` must be `areEqual` in `block`.
523+
* If `areEqual = false` then this implies `op != k`.
524+
*/
525+
cached
526+
predicate ensuresEq(Operand op, int k, IRBlock block, boolean areEqual) {
527+
exists(AbstractValue value |
528+
compares_eq(this, op, k, areEqual, value) and this.valueControls(block, value)
460529
)
461530
}
462531

@@ -468,9 +537,21 @@ class IRGuardCondition extends Instruction {
468537
predicate ensuresEqEdge(
469538
Operand left, Operand right, int k, IRBlock pred, IRBlock succ, boolean areEqual
470539
) {
471-
exists(boolean testIsTrue |
472-
compares_eq(this, left, right, k, areEqual, testIsTrue) and
473-
this.controlsEdge(pred, succ, testIsTrue)
540+
exists(AbstractValue value |
541+
compares_eq(this, left, right, k, areEqual, value) and
542+
this.valueControlsEdge(pred, succ, value)
543+
)
544+
}
545+
546+
/**
547+
* Holds if (determined by this guard) `op == k` must be `areEqual` on the edge from
548+
* `pred` to `succ`. If `areEqual = false` then this implies `op != k`.
549+
*/
550+
cached
551+
predicate ensuresEqEdge(Operand op, int k, IRBlock pred, IRBlock succ, boolean areEqual) {
552+
exists(AbstractValue value |
553+
compares_eq(this, op, k, areEqual, value) and
554+
this.valueControlsEdge(pred, succ, value)
474555
)
475556
}
476557

@@ -572,52 +653,98 @@ private Instruction getBranchForCondition(Instruction guard) {
572653
* Beware making mistaken logical implications here relating `areEqual` and `testIsTrue`.
573654
*/
574655
private predicate compares_eq(
575-
Instruction test, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
656+
Instruction test, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
576657
) {
577658
/* The simple case where the test *is* the comparison so areEqual = testIsTrue xor eq. */
578-
exists(boolean eq | simple_comparison_eq(test, left, right, k, eq) |
579-
areEqual = true and testIsTrue = eq
659+
exists(AbstractValue v | simple_comparison_eq(test, left, right, k, v) |
660+
areEqual = true and value = v
580661
or
581-
areEqual = false and testIsTrue = eq.booleanNot()
662+
areEqual = false and value = v.getDualValue()
582663
)
583664
or
584665
// I think this is handled by forwarding in controlsBlock.
585666
//or
586667
//logical_comparison_eq(test, left, right, k, areEqual, testIsTrue)
587668
/* a == b + k => b == a - k */
588-
exists(int mk | k = -mk | compares_eq(test, right, left, mk, areEqual, testIsTrue))
669+
exists(int mk | k = -mk | compares_eq(test, right, left, mk, areEqual, value))
589670
or
590-
complex_eq(test, left, right, k, areEqual, testIsTrue)
671+
complex_eq(test, left, right, k, areEqual, value)
591672
or
592673
/* (x is true => (left == right + k)) => (!x is false => (left == right + k)) */
593-
exists(boolean isFalse | testIsTrue = isFalse.booleanNot() |
594-
compares_eq(test.(LogicalNotInstruction).getUnary(), left, right, k, areEqual, isFalse)
674+
exists(AbstractValue dual | value = dual.getDualValue() |
675+
compares_eq(test.(LogicalNotInstruction).getUnary(), left, right, k, areEqual, dual)
676+
)
677+
}
678+
679+
/** Holds if `op == k` is `areEqual` given that `test` is equal to `value`. */
680+
private predicate compares_eq(
681+
Instruction test, Operand op, int k, boolean areEqual, AbstractValue value
682+
) {
683+
/* The simple case where the test *is* the comparison so areEqual = testIsTrue xor eq. */
684+
exists(AbstractValue v | simple_comparison_eq(test, op, k, v) |
685+
areEqual = true and value = v
686+
or
687+
areEqual = false and value = v.getDualValue()
688+
)
689+
or
690+
complex_eq(test, op, k, areEqual, value)
691+
or
692+
/* (x is true => (op == k)) => (!x is false => (op == k)) */
693+
exists(AbstractValue dual | value = dual.getDualValue() |
694+
compares_eq(test.(LogicalNotInstruction).getUnary(), op, k, areEqual, dual)
695+
)
696+
or
697+
// ((test is `areEqual` => op == const + k2) and const == `k1`) =>
698+
// test is `areEqual` => op == k1 + k2
699+
exists(int k1, int k2, ConstantInstruction const |
700+
compares_eq(test, op, const.getAUse(), k2, areEqual, value) and
701+
int_value(const) = k1 and
702+
k = k1 + k2
595703
)
596704
}
597705

598706
/** Rearrange various simple comparisons into `left == right + k` form. */
599707
private predicate simple_comparison_eq(
600-
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual
708+
CompareInstruction cmp, Operand left, Operand right, int k, AbstractValue value
601709
) {
602710
left = cmp.getLeftOperand() and
603711
cmp instanceof CompareEQInstruction and
604712
right = cmp.getRightOperand() and
605713
k = 0 and
606-
areEqual = true
714+
value.(BooleanValue).getValue() = true
607715
or
608716
left = cmp.getLeftOperand() and
609717
cmp instanceof CompareNEInstruction and
610718
right = cmp.getRightOperand() and
611719
k = 0 and
612-
areEqual = false
720+
value.(BooleanValue).getValue() = false
721+
}
722+
723+
/** Rearrange various simple comparisons into `op == k` form. */
724+
private predicate simple_comparison_eq(Instruction test, Operand op, int k, AbstractValue value) {
725+
exists(SwitchInstruction switch, CaseEdge case |
726+
test = switch.getExpression() and
727+
op.getDef() = test and
728+
case = value.(MatchValue).getCase() and
729+
exists(switch.getSuccessor(case)) and
730+
case.getValue().toInt() = k
731+
)
732+
}
733+
734+
private predicate complex_eq(
735+
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
736+
) {
737+
sub_eq(cmp, left, right, k, areEqual, value)
738+
or
739+
add_eq(cmp, left, right, k, areEqual, value)
613740
}
614741

615742
private predicate complex_eq(
616-
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
743+
Instruction test, Operand op, int k, boolean areEqual, AbstractValue value
617744
) {
618-
sub_eq(cmp, left, right, k, areEqual, testIsTrue)
745+
sub_eq(test, op, k, areEqual, value)
619746
or
620-
add_eq(cmp, left, right, k, areEqual, testIsTrue)
747+
add_eq(test, op, k, areEqual, value)
621748
}
622749

623750
/*
@@ -768,44 +895,61 @@ private predicate add_lt(
768895
// left - x == right + c => left == right + (c+x)
769896
// left == (right - x) + c => left == right + (c-x)
770897
private predicate sub_eq(
771-
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
898+
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
772899
) {
773900
exists(SubInstruction lhs, int c, int x |
774-
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
901+
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
775902
left = lhs.getLeftOperand() and
776903
x = int_value(lhs.getRight()) and
777904
k = c + x
778905
)
779906
or
780907
exists(SubInstruction rhs, int c, int x |
781-
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
908+
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
782909
right = rhs.getLeftOperand() and
783910
x = int_value(rhs.getRight()) and
784911
k = c - x
785912
)
786913
or
787914
exists(PointerSubInstruction lhs, int c, int x |
788-
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
915+
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
789916
left = lhs.getLeftOperand() and
790917
x = int_value(lhs.getRight()) and
791918
k = c + x
792919
)
793920
or
794921
exists(PointerSubInstruction rhs, int c, int x |
795-
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
922+
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
796923
right = rhs.getLeftOperand() and
797924
x = int_value(rhs.getRight()) and
798925
k = c - x
799926
)
800927
}
801928

929+
// op - x == c => op == (c+x)
930+
private predicate sub_eq(Instruction test, Operand op, int k, boolean areEqual, AbstractValue value) {
931+
exists(SubInstruction sub, int c, int x |
932+
compares_eq(test, sub.getAUse(), c, areEqual, value) and
933+
op = sub.getLeftOperand() and
934+
x = int_value(sub.getRight()) and
935+
k = c + x
936+
)
937+
or
938+
exists(PointerSubInstruction sub, int c, int x |
939+
compares_eq(test, sub.getAUse(), c, areEqual, value) and
940+
op = sub.getLeftOperand() and
941+
x = int_value(sub.getRight()) and
942+
k = c + x
943+
)
944+
}
945+
802946
// left + x == right + c => left == right + (c-x)
803947
// left == (right + x) + c => left == right + (c+x)
804948
private predicate add_eq(
805-
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, boolean testIsTrue
949+
CompareInstruction cmp, Operand left, Operand right, int k, boolean areEqual, AbstractValue value
806950
) {
807951
exists(AddInstruction lhs, int c, int x |
808-
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
952+
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
809953
(
810954
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
811955
or
@@ -815,7 +959,7 @@ private predicate add_eq(
815959
)
816960
or
817961
exists(AddInstruction rhs, int c, int x |
818-
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
962+
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
819963
(
820964
right = rhs.getLeftOperand() and x = int_value(rhs.getRight())
821965
or
@@ -825,7 +969,7 @@ private predicate add_eq(
825969
)
826970
or
827971
exists(PointerAddInstruction lhs, int c, int x |
828-
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, testIsTrue) and
972+
compares_eq(cmp, lhs.getAUse(), right, c, areEqual, value) and
829973
(
830974
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
831975
or
@@ -835,7 +979,7 @@ private predicate add_eq(
835979
)
836980
or
837981
exists(PointerAddInstruction rhs, int c, int x |
838-
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, testIsTrue) and
982+
compares_eq(cmp, left, rhs.getAUse(), c, areEqual, value) and
839983
(
840984
right = rhs.getLeftOperand() and x = int_value(rhs.getRight())
841985
or
@@ -845,5 +989,30 @@ private predicate add_eq(
845989
)
846990
}
847991

992+
// left + x == right + c => left == right + (c-x)
993+
private predicate add_eq(
994+
Instruction test, Operand left, int k, boolean areEqual, AbstractValue value
995+
) {
996+
exists(AddInstruction lhs, int c, int x |
997+
compares_eq(test, lhs.getAUse(), c, areEqual, value) and
998+
(
999+
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
1000+
or
1001+
left = lhs.getRightOperand() and x = int_value(lhs.getLeft())
1002+
) and
1003+
k = c - x
1004+
)
1005+
or
1006+
exists(PointerAddInstruction lhs, int c, int x |
1007+
compares_eq(test, lhs.getAUse(), c, areEqual, value) and
1008+
(
1009+
left = lhs.getLeftOperand() and x = int_value(lhs.getRight())
1010+
or
1011+
left = lhs.getRightOperand() and x = int_value(lhs.getLeft())
1012+
) and
1013+
k = c - x
1014+
)
1015+
}
1016+
8481017
/** The int value of integer constant expression. */
8491018
private int int_value(Instruction i) { result = i.(IntegerConstantInstruction).getValue().toInt() }

cpp/ql/lib/semmle/code/cpp/ir/implementation/EdgeKind.qll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ class CaseEdge extends EdgeKind, TCaseEdge {
9090
* Gets the largest value of the switch expression for which control will flow along this edge.
9191
*/
9292
final string getMaxValue() { result = maxValue }
93+
94+
/**
95+
* Gets the unique value of the switch expression for which control will
96+
* flow along this edge, if any.
97+
*/
98+
final string getValue() {
99+
minValue = maxValue and
100+
result = minValue
101+
}
93102
}
94103

95104
/**

0 commit comments

Comments
 (0)