Skip to content

Commit 4f1bd4e

Browse files
committed
Rename compare_ops and add test for ifelse
1 parent 535336c commit 4f1bd4e

File tree

6 files changed

+50
-43
lines changed

6 files changed

+50
-43
lines changed

include/pyoptinterface/nlexpr.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ enum class BinaryOperator
6868
Pow,
6969

7070
// compare
71-
Lessthan,
72-
Lessequal,
71+
LessThan,
72+
LessEqual,
7373
Equal,
74-
Notequal,
75-
Greaterequal,
76-
Greaterthan,
74+
NotEqual,
75+
GreaterEqual,
76+
GreaterThan,
7777
};
7878

7979
bool is_binary_compare_op(BinaryOperator op);

lib/cppad_interface.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,12 @@ CppAD::AD<double> cppad_build_binary_expression(BinaryOperator op, const CppAD::
262262
case BinaryOperator::Pow: {
263263
return CppAD::pow(left, right);
264264
}
265-
case BinaryOperator::Lessthan:
266-
case BinaryOperator::Lessequal:
265+
case BinaryOperator::LessThan:
266+
case BinaryOperator::LessEqual:
267267
case BinaryOperator::Equal:
268-
case BinaryOperator::Notequal:
269-
case BinaryOperator::Greaterequal:
270-
case BinaryOperator::Greaterthan: {
268+
case BinaryOperator::NotEqual:
269+
case BinaryOperator::GreaterEqual:
270+
case BinaryOperator::GreaterThan: {
271271
throw std::runtime_error("Currently comparision operator can only be used with ifelse "
272272
"function and cannot be evaluated as value");
273273
}
@@ -285,22 +285,22 @@ CppAD::AD<double> cppad_build_ternary_expression(BinaryOperator compare_op,
285285
{
286286
switch (compare_op)
287287
{
288-
case BinaryOperator::Lessthan: {
288+
case BinaryOperator::LessThan: {
289289
return CppAD::CondExpLt(compare_left, compare_right, then_result, else_result);
290290
}
291-
case BinaryOperator::Lessequal: {
291+
case BinaryOperator::LessEqual: {
292292
return CppAD::CondExpLe(compare_left, compare_right, then_result, else_result);
293293
}
294294
case BinaryOperator::Equal: {
295295
return CppAD::CondExpEq(compare_left, compare_right, then_result, else_result);
296296
}
297-
case BinaryOperator::Notequal: {
297+
case BinaryOperator::NotEqual: {
298298
return CppAD::CondExpEq(compare_left, compare_right, else_result, then_result);
299299
}
300-
case BinaryOperator::Greaterequal: {
300+
case BinaryOperator::GreaterEqual: {
301301
return CppAD::CondExpGe(compare_left, compare_right, then_result, else_result);
302302
}
303-
case BinaryOperator::Greaterthan: {
303+
case BinaryOperator::GreaterThan: {
304304
return CppAD::CondExpGt(compare_left, compare_right, then_result, else_result);
305305
}
306306
default: {

lib/nlexpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,5 @@ NaryOperator ExpressionGraph::get_nary_operator(const ExpressionHandle &expressi
8383

8484
bool is_binary_compare_op(BinaryOperator op)
8585
{
86-
return (op >= BinaryOperator::Lessthan) && (op <= BinaryOperator::Greaterthan);
86+
return (op >= BinaryOperator::LessThan) && (op <= BinaryOperator::GreaterThan);
8787
}

lib/nlexpr_ext.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ NB_MODULE(nlexpr_ext, m)
4646
.value("Div", BinaryOperator::Div)
4747
.value("Pow", BinaryOperator::Pow)
4848
// compare ops
49-
.value("Lessthan", BinaryOperator::Lessthan)
50-
.value("Lessequal", BinaryOperator::Lessequal)
49+
.value("LessThan", BinaryOperator::LessThan)
50+
.value("LessEqual", BinaryOperator::LessEqual)
5151
.value("Equal", BinaryOperator::Equal)
52-
.value("Notequal", BinaryOperator::Notequal)
53-
.value("Greaterequal", BinaryOperator::Greaterequal)
54-
.value("Greaterthan", BinaryOperator::Greaterthan);
52+
.value("NotEqual", BinaryOperator::NotEqual)
53+
.value("GreaterEqual", BinaryOperator::GreaterEqual)
54+
.value("GreaterThan", BinaryOperator::GreaterThan);
5555

5656
nb::enum_<TernaryOperator>(m, "TernaryOperator")
5757
.value("IfThenElse", TernaryOperator::IfThenElse);

src/pyoptinterface/_src/function_tracing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def pow_int(graph, expr, N):
2929

3030
M, r = divmod(N, 2)
3131

32-
pow_2 = graph.add_binary(BinaryOperator.Mul, [expr, expr])
32+
pow_2 = graph.add_nary(NaryOperator.Mul, [expr, expr])
3333
pow_2M = pow_int(graph, pow_2, M)
3434

3535
if r == 0:
3636
return pow_2M
3737
else:
38-
return graph.add_binary(BinaryOperator.Mul, [pow_2M, expr])
38+
return graph.add_nary(NaryOperator.Mul, [pow_2M, expr])
3939

4040

4141
@dataclass

tests/test_nlp.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,6 @@ def con(vars):
4949
nlfunc.sqrt,
5050
nlfunc.tan,
5151
]
52-
py_funcs = [
53-
abs,
54-
math.acos,
55-
math.asin,
56-
math.atan,
57-
math.cos,
58-
math.exp,
59-
math.log,
60-
math.pow,
61-
math.sin,
62-
math.sqrt,
63-
math.tan,
64-
]
6552

6653
def all_nlfuncs(vars):
6754
x = vars.x
@@ -102,13 +89,9 @@ def all_nlfuncs(vars):
10289
all_nlfuncs_con, poi.ConstraintAttribute.Primal
10390
)
10491

105-
correct_con_values = []
106-
for f in py_funcs:
107-
if f == math.pow:
108-
v = f(x_value, 2)
109-
else:
110-
v = f(x_value)
111-
correct_con_values.append(v)
92+
vars = nlfunc.Vars(x=x_value)
93+
correct_con_values = all_nlfuncs(vars)
94+
11295
assert con_values == pytest.approx(correct_con_values)
11396

11497

@@ -157,6 +140,30 @@ def con(vars, params):
157140
assert x_values == pytest.approx(correct_x_values)
158141

159142

143+
def test_nlfunc_ifelse():
144+
if not ipopt.is_library_loaded():
145+
pytest.skip("Ipopt library is not loaded")
146+
147+
for x_, fx in zip([0.2, 0.5, 1.0, 2.0, 3.0], [0.2, 0.5, 1.0, 4.0, 9.0]):
148+
model = ipopt.Model()
149+
150+
x = model.add_variable(lb=0.0, ub=10.0, start=1.0)
151+
152+
def con(vars):
153+
x = vars.x
154+
return nlfunc.ifelse(x > 1.0, x**2, x)
155+
156+
con_f = model.register_function(con)
157+
model.add_nl_constraint(con_f, vars=nlfunc.Vars(x=x), lb=[fx])
158+
159+
model.set_objective(x)
160+
161+
model.optimize()
162+
163+
x_value = model.get_value(x)
164+
assert x_value == pytest.approx(x_)
165+
166+
160167
if __name__ == "__main__":
161168
test_ipopt()
162169
test_nlp_param()

0 commit comments

Comments
 (0)