Skip to content

Commit 6f35be4

Browse files
committed
Add case_exprt to std_expr.h and refactor code to use it
Introduces `case_exprt` for improved type safety. Fixes: #3037
1 parent 4fe3ade commit 6f35be4

File tree

6 files changed

+278
-8
lines changed

6 files changed

+278
-8
lines changed

src/solvers/flattening/boolbv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
123123
else if(expr.id() == ID_update_bit)
124124
return convert_update_bit(to_update_bit_expr(expr));
125125
else if(expr.id()==ID_case)
126-
return convert_case(expr);
126+
return convert_case(to_case_expr(expr));
127127
else if(expr.id()==ID_cond)
128128
return convert_cond(to_cond_expr(expr));
129129
else if(expr.id()==ID_if)
@@ -390,7 +390,7 @@ literalt boolbvt::convert_rest(const exprt &expr)
390390
}
391391
else if(expr.id()==ID_case)
392392
{
393-
bvt bv=convert_case(expr);
393+
bvt bv = convert_case(to_case_expr(expr));
394394
CHECK_RETURN(bv.size() == 1);
395395
return bv[0];
396396
}

src/solvers/flattening/boolbv.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class boolbvt:public arrayst
185185
virtual bvt convert_update(const update_exprt &);
186186
virtual bvt convert_update_bit(const update_bit_exprt &);
187187
virtual bvt convert_update_bits(const update_bits_exprt &);
188-
virtual bvt convert_case(const exprt &expr);
188+
virtual bvt convert_case(const case_exprt &expr);
189189
virtual bvt convert_cond(const cond_exprt &);
190190
virtual bvt convert_shift(const binary_exprt &expr);
191191
virtual bvt convert_bitwise(const exprt &expr);

src/solvers/flattening/boolbv_case.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@ Author: Daniel Kroening, [email protected]
66
77
\*******************************************************************/
88

9-
#include "boolbv.h"
10-
119
#include <util/invariant.h>
10+
#include <util/std_expr.h>
1211

13-
bvt boolbvt::convert_case(const exprt &expr)
14-
{
15-
PRECONDITION(expr.id() == ID_case);
12+
#include "boolbv.h"
1613

14+
bvt boolbvt::convert_case(const case_exprt &expr)
15+
{
1716
const std::vector<exprt> &operands=expr.operands();
1817

1918
std::size_t width=boolbv_width(expr.type());

src/util/std_expr.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3578,6 +3578,123 @@ inline cond_exprt &to_cond_expr(exprt &expr)
35783578
return ret;
35793579
}
35803580

3581+
/// \brief Case expression: evaluates to the value corresponding to the first
3582+
/// matching case. The first operand is the value to compare against. Subsequent
3583+
/// operands alternate between compare values and result values.
3584+
/// The syntax is: case(select_value, case1_value, result1, case2_value, result2, ...)
3585+
class case_exprt : public multi_ary_exprt
3586+
{
3587+
public:
3588+
case_exprt(operandst _operands, typet _type)
3589+
: multi_ary_exprt(ID_case, std::move(_operands), std::move(_type))
3590+
{
3591+
}
3592+
3593+
/// Constructor with select value
3594+
case_exprt(exprt _select_value, typet _type)
3595+
: multi_ary_exprt(ID_case, {std::move(_select_value)}, std::move(_type))
3596+
{
3597+
}
3598+
3599+
/// Get the value that is being compared against
3600+
const exprt &select_value() const
3601+
{
3602+
PRECONDITION(!operands().empty());
3603+
return operands()[0];
3604+
}
3605+
3606+
/// Get the value that is being compared against
3607+
exprt &select_value()
3608+
{
3609+
PRECONDITION(!operands().empty());
3610+
return operands()[0];
3611+
}
3612+
3613+
/// Add a case: value to compare and corresponding result
3614+
/// \param case_value: the value to compare against select_value
3615+
/// \param result_value: the value to return if case_value matches select_value
3616+
void add_case(const exprt &case_value, const exprt &result_value)
3617+
{
3618+
operands().reserve(operands().size() + 2);
3619+
operands().push_back(case_value);
3620+
operands().push_back(result_value);
3621+
}
3622+
3623+
/// Get the number of cases (excluding the select value)
3624+
std::size_t number_of_cases() const
3625+
{
3626+
PRECONDITION(operands().size() >= 1);
3627+
return (operands().size() - 1) / 2;
3628+
}
3629+
3630+
/// Get the case value for the i-th case
3631+
const exprt &case_value(std::size_t i) const
3632+
{
3633+
PRECONDITION(i < number_of_cases());
3634+
return operands()[1 + 2 * i];
3635+
}
3636+
3637+
/// Get the case value for the i-th case
3638+
exprt &case_value(std::size_t i)
3639+
{
3640+
PRECONDITION(i < number_of_cases());
3641+
return operands()[1 + 2 * i];
3642+
}
3643+
3644+
/// Get the result value for the i-th case
3645+
const exprt &result_value(std::size_t i) const
3646+
{
3647+
PRECONDITION(i < number_of_cases());
3648+
return operands()[1 + 2 * i + 1];
3649+
}
3650+
3651+
/// Get the result value for the i-th case
3652+
exprt &result_value(std::size_t i)
3653+
{
3654+
PRECONDITION(i < number_of_cases());
3655+
return operands()[1 + 2 * i + 1];
3656+
}
3657+
};
3658+
3659+
template <>
3660+
inline bool can_cast_expr<case_exprt>(const exprt &base)
3661+
{
3662+
return base.id() == ID_case;
3663+
}
3664+
3665+
inline void validate_expr(const case_exprt &value)
3666+
{
3667+
DATA_INVARIANT(
3668+
value.operands().size() >= 1,
3669+
"case expression must have at least one operand");
3670+
DATA_INVARIANT(
3671+
value.operands().size() % 2 == 1,
3672+
"case expression must have odd number of operands");
3673+
}
3674+
3675+
/// \brief Cast an exprt to a \ref case_exprt
3676+
///
3677+
/// \a expr must be known to be \ref case_exprt.
3678+
///
3679+
/// \param expr: Source expression
3680+
/// \return Object of type \ref case_exprt
3681+
inline const case_exprt &to_case_expr(const exprt &expr)
3682+
{
3683+
PRECONDITION(expr.id() == ID_case);
3684+
const case_exprt &ret = static_cast<const case_exprt &>(expr);
3685+
validate_expr(ret);
3686+
return ret;
3687+
}
3688+
3689+
/// \copydoc to_case_expr(const exprt &)
3690+
inline case_exprt &to_case_expr(exprt &expr)
3691+
{
3692+
PRECONDITION(expr.id() == ID_case);
3693+
case_exprt &ret = static_cast<case_exprt &>(expr);
3694+
validate_expr(ret);
3695+
return ret;
3696+
}
3697+
35813698
/// \brief Expression to define a mapping from an argument (index) to elements.
35823699
/// This enables constructing an array via an anonymous function.
35833700
/// Not all kinds of array comprehension can be expressed, only those of the

test_case_expr_example.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*******************************************************************\
2+
3+
Example: Demonstrating the new case_exprt usage
4+
5+
Author: Example
6+
7+
\*******************************************************************/
8+
9+
#include <util/arith_tools.h>
10+
#include <util/bitvector_types.h>
11+
#include <util/std_expr.h>
12+
13+
#include <iostream>
14+
15+
int main()
16+
{
17+
// Create integer type
18+
const signedbv_typet int_type(32);
19+
20+
// Create select value (the variable being matched)
21+
const symbol_exprt x("x", int_type);
22+
23+
// Create a case expression: case(x, 1, 10, 2, 20, 3, 30)
24+
// This means: if x==1 return 10, if x==2 return 20, if x==3 return 30
25+
case_exprt case_expr(x, int_type);
26+
27+
// Add cases
28+
case_expr.add_case(from_integer(1, int_type), from_integer(10, int_type));
29+
case_expr.add_case(from_integer(2, int_type), from_integer(20, int_type));
30+
case_expr.add_case(from_integer(3, int_type), from_integer(30, int_type));
31+
32+
// Demonstrate usage
33+
std::cout << "Expression ID: " << case_expr.id() << std::endl;
34+
std::cout << "Number of cases: " << case_expr.number_of_cases() << std::endl;
35+
std::cout << "Total operands: " << case_expr.operands().size() << std::endl;
36+
37+
// Show each case
38+
for(std::size_t i = 0; i < case_expr.number_of_cases(); ++i)
39+
{
40+
auto case_val = to_constant_expr(case_expr.case_value(i));
41+
auto result_val = to_constant_expr(case_expr.result_value(i));
42+
std::cout << "Case " << i << ": when x==" << case_val.get_value()
43+
<< " return " << result_val.get_value() << std::endl;
44+
}
45+
46+
// Demonstrate conversion
47+
exprt &base_expr = case_expr;
48+
if(can_cast_expr<case_exprt>(base_expr))
49+
{
50+
case_exprt &converted = to_case_expr(base_expr);
51+
std::cout << "Successfully converted back to case_exprt" << std::endl;
52+
std::cout << "Converted has " << converted.number_of_cases() << " cases"
53+
<< std::endl;
54+
}
55+
56+
return 0;
57+
}

unit/util/case_expr.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*******************************************************************\
2+
3+
Module: Unit tests for case_exprt
4+
5+
Author: Unit test
6+
7+
\*******************************************************************/
8+
9+
#include <util/arith_tools.h>
10+
#include <util/bitvector_types.h>
11+
#include <util/std_expr.h>
12+
13+
#include <testing-utils/use_catch.h>
14+
15+
TEST_CASE("case_exprt construction and access", "[core][util][case_expr]")
16+
{
17+
const signedbv_typet int_type(32);
18+
const symbol_exprt select_value("x", int_type);
19+
20+
SECTION("Basic construction")
21+
{
22+
case_exprt case_expr(select_value, int_type);
23+
24+
REQUIRE(case_expr.id() == ID_case);
25+
REQUIRE(case_expr.select_value() == select_value);
26+
REQUIRE(case_expr.number_of_cases() == 0);
27+
}
28+
29+
SECTION("Adding cases")
30+
{
31+
case_exprt case_expr(select_value, int_type);
32+
33+
const constant_exprt case1_value = from_integer(1, int_type);
34+
const constant_exprt result1_value = from_integer(10, int_type);
35+
36+
const constant_exprt case2_value = from_integer(2, int_type);
37+
const constant_exprt result2_value = from_integer(20, int_type);
38+
39+
case_expr.add_case(case1_value, result1_value);
40+
REQUIRE(case_expr.number_of_cases() == 1);
41+
REQUIRE(case_expr.case_value(0) == case1_value);
42+
REQUIRE(case_expr.result_value(0) == result1_value);
43+
44+
case_expr.add_case(case2_value, result2_value);
45+
REQUIRE(case_expr.number_of_cases() == 2);
46+
REQUIRE(case_expr.case_value(1) == case2_value);
47+
REQUIRE(case_expr.result_value(1) == result2_value);
48+
49+
// Verify operands structure: 1 select + 2*2 case/result pairs = 5
50+
REQUIRE(case_expr.operands().size() == 5);
51+
// Verify odd number of operands
52+
REQUIRE(case_expr.operands().size() % 2 == 1);
53+
}
54+
55+
SECTION("to_case_expr conversion")
56+
{
57+
case_exprt case_expr(select_value, int_type);
58+
const constant_exprt case_value = from_integer(1, int_type);
59+
const constant_exprt result_value = from_integer(10, int_type);
60+
case_expr.add_case(case_value, result_value);
61+
62+
exprt &base = case_expr;
63+
case_exprt &converted = to_case_expr(base);
64+
65+
REQUIRE(&converted == &case_expr);
66+
REQUIRE(converted.number_of_cases() == 1);
67+
REQUIRE(converted.case_value(0) == case_value);
68+
}
69+
70+
SECTION("can_cast_expr")
71+
{
72+
case_exprt case_expr(select_value, int_type);
73+
exprt &base = case_expr;
74+
75+
REQUIRE(can_cast_expr<case_exprt>(base));
76+
REQUIRE_FALSE(can_cast_expr<if_exprt>(base));
77+
}
78+
79+
SECTION("Construction with operands")
80+
{
81+
const constant_exprt case_value = from_integer(1, int_type);
82+
const constant_exprt result_value = from_integer(10, int_type);
83+
84+
case_exprt::operandst ops;
85+
ops.push_back(select_value);
86+
ops.push_back(case_value);
87+
ops.push_back(result_value);
88+
89+
case_exprt case_expr(std::move(ops), int_type);
90+
91+
REQUIRE(case_expr.id() == ID_case);
92+
REQUIRE(case_expr.number_of_cases() == 1);
93+
REQUIRE(case_expr.select_value() == select_value);
94+
REQUIRE(case_expr.case_value(0) == case_value);
95+
REQUIRE(case_expr.result_value(0) == result_value);
96+
}
97+
}

0 commit comments

Comments
 (0)