Skip to content

Commit 791ae22

Browse files
committed
finish merge
2 parents 760c522 + acb21dc commit 791ae22

File tree

15 files changed

+357
-14
lines changed

15 files changed

+357
-14
lines changed

cvxpy/atoms/elementwise/entr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,6 @@ def _domain(self) -> List[Constraint]:
9393
"""Returns constraints describing the domain of the node.
9494
"""
9595
return [self.args[0] >= 0]
96+
97+
def point_in_domain(self):
98+
return np.ones(self.shape)

cvxpy/atoms/elementwise/kl_div.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,6 @@ def _domain(self) -> List[Constraint]:
9494
"""Returns constraints describing the domain of the node.
9595
"""
9696
return [self.args[0] >= 0, self.args[1] >= 0]
97+
98+
def point_in_domain(self, argument=0):
99+
return np.ones(self.args[argument].shape)

cvxpy/atoms/elementwise/log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,4 @@ def _domain(self) -> List[Constraint]:
9999
def point_in_domain(self) -> np.ndarray:
100100
"""Returns a point in the domain of the node.
101101
"""
102-
return np.ones(self.size)
102+
return np.ones(self.shape)

cvxpy/atoms/elementwise/rel_entr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,6 @@ def _domain(self):
9595
"""Returns constraints describing the domain of the node.
9696
"""
9797
return [self.args[0] >= 0, self.args[1] >= 0]
98+
99+
def point_in_domain(self, argument=0):
100+
return np.ones(self.args[argument].shape)

cvxpy/reductions/expr2smooth/canonicalizers/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""
1616
from cvxpy.atoms import maximum
1717
from cvxpy.atoms.elementwise.log import log
18+
from cvxpy.atoms.elementwise.entr import entr
19+
from cvxpy.atoms.elementwise.rel_entr import rel_entr
20+
from cvxpy.atoms.elementwise.kl_div import kl_div
1821
from cvxpy.atoms.elementwise.minimum import minimum
1922
from cvxpy.atoms.elementwise.power import power
2023
from cvxpy.atoms.pnorm import Pnorm
@@ -28,6 +31,9 @@
2831
from cvxpy.reductions.expr2smooth.canonicalizers.pnorm_canon import pnorm_canon
2932
from cvxpy.reductions.expr2smooth.canonicalizers.power_canon import power_canon
3033
from cvxpy.reductions.expr2smooth.canonicalizers.maximum_canon import maximum_canon
34+
from cvxpy.reductions.expr2smooth.canonicalizers.entr_canon import entr_canon
35+
from cvxpy.reductions.expr2smooth.canonicalizers.rel_entr_canon import rel_entr_canon
36+
from cvxpy.reductions.expr2smooth.canonicalizers.kl_div_canon import kl_div_canon
3137

3238
CANON_METHODS = {
3339
abs: abs_canon,
@@ -37,5 +43,8 @@
3743
power: power_canon,
3844
Pnorm : pnorm_canon,
3945
DivExpression: div_canon,
40-
multiply: mul_canon
46+
multiply: mul_canon,
47+
entr: entr_canon,
48+
rel_entr: rel_entr_canon,
49+
kl_div: kl_div_canon,
4150
}

cvxpy/reductions/expr2smooth/canonicalizers/abs_canon.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,38 @@
1616

1717
import numpy as np
1818

19+
1920
from cvxpy.expressions.variable import Variable
21+
from cvxpy.atoms.elementwise.power import power
22+
23+
# TODO (DCED): ask William if this the multiplication we want to use
24+
from cvxpy.atoms.affine.binary_operators import multiply
25+
from cvxpy.reductions.expr2smooth.canonicalizers.power_canon import power_canon
2026

27+
#def abs_canon(expr, args):
28+
# shape = expr.shape
29+
# t1 = Variable(shape, bounds = [0, None])
30+
# if expr.value is not None:
31+
# #t1.value = np.sqrt(expr.value**2)
32+
# t1.value = np.abs(args[0].value)
33+
#
34+
# #return t1, [t1**2 == args[0] ** 2]
35+
# square_expr = power(args[0], 2)
36+
# t2, constr_sq = power_canon(square_expr, square_expr.args)
37+
# return t1, [t1**2 == t2] + constr_sq
2138

2239
def abs_canon(expr, args):
2340
shape = expr.shape
24-
t = Variable(shape)
25-
if expr.value is not None:
26-
t.value = np.sqrt(expr.value**2)
27-
return t, [t**2 == expr**2, t >= 0]
41+
t1 = Variable(shape, bounds = [0, None])
42+
y = Variable(shape, bounds = [-1.01, 1.01])
43+
if args[0].value is not None:
44+
#t1.value = np.sqrt(expr.value**2)
45+
t1.value = np.abs(args[0].value)
46+
y.value = np.sign(args[0].value)
47+
48+
t1.value = np.ones(shape)
49+
y.value = np.zeros(shape)
50+
51+
# TODO (DCED): check how multiply is canonicalized. We don't want to introduce a new variable for
52+
# y inside multiply. But args[0] should potentially be canonicalized further?
53+
return t1, [y ** 2 == np.ones(shape), t1 == multiply(y, args[0])]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
Copyright 2025 CVXPY developers
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from cvxpy.expressions.variable import Variable
18+
import numpy as np
19+
20+
def entr_canon(expr, args):
21+
t = Variable(args[0].shape, bounds=[0, None])
22+
if args[0].value is not None and np.all(args[0].value >= 1):
23+
t.value = args[0].value
24+
else:
25+
t.value = expr.point_in_domain()
26+
27+
return expr.copy([t]), [t==args[0]]
28+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
Copyright 2025 CVXPY developers
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from cvxpy.expressions.variable import Variable
18+
import numpy as np
19+
20+
def kl_div_canon(expr, args):
21+
constraints = []
22+
23+
if not args[0].is_constant():
24+
t1 = Variable(args[0].shape, bounds=[0, None])
25+
constraints.append(t1 == args[0])
26+
27+
if args[0].value is not None and np.all(args[0].value >= 1):
28+
t1.value = args[0].value
29+
else:
30+
t1.value = expr.point_in_domain(argument=0)
31+
else:
32+
t1 = args[0]
33+
34+
if not args[1].is_constant():
35+
t2 = Variable(args[1].shape, bounds=[0, None])
36+
constraints.append(t2 == args[1])
37+
38+
if args[1].value is not None and np.all(args[1].value >= 1):
39+
t2.value = args[1].value
40+
else:
41+
t2.value = expr.point_in_domain(argument=1)
42+
else:
43+
t2 = args[1]
44+
45+
return expr.copy([t1, t2]), constraints

cvxpy/reductions/expr2smooth/canonicalizers/log_canon.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,24 @@
1818

1919
from cvxpy.expressions.constants import Constant
2020
from cvxpy.expressions.variable import Variable
21+
from cvxpy.expressions.constants import Constant
22+
import numpy as np
23+
from cvxpy.atoms.elementwise.exp import exp
24+
25+
def collect_constant_and_variable(expr, constants, variable):
26+
if isinstance(expr, Constant):
27+
constants.append(expr)
28+
elif isinstance(expr, Variable):
29+
variable.append(expr)
30+
elif hasattr(expr, "args"):
31+
for subexpr in expr.args:
32+
collect_constant_and_variable(subexpr, constants, variable)
2133

34+
assert(len(variable) <= 1)
35+
36+
# DCED: Without this lower bound the stress test for ML Gaussian non-zero mean fails.
37+
# Perhaps this should be a parameter exposed to the user?
38+
LOWER_BOUND = 1e-5
2239

2340
def collect_constant_and_variable(expr, constants, variable):
2441
if isinstance(expr, Constant):
@@ -36,7 +53,7 @@ def collect_constant_and_variable(expr, constants, variable):
3653
LOWER_BOUND = 1e-5
3754

3855
def log_canon(expr, args):
39-
t = Variable(args[0].size, bounds=[LOWER_BOUND, None], name='t')
56+
t = Variable(args[0].size, bounds=[LOWER_BOUND, None])
4057

4158
# DCED: if args[0] is a * x for a constant scalar or vector 'a'
4259
# and a vector variable 'x', we want to add bounds to x if x
@@ -71,3 +88,15 @@ def log_canon(expr, args):
7188
t.value = args[0].value
7289

7390
return expr.copy([t]), [t==args[0]]
91+
92+
# TODO (DCED): On some problems this canonicalization seems to work better.
93+
# We should investigate this further when we have more benchmarks
94+
# involving log.
95+
#def log_canon(expr, args):
96+
# t = Variable(args[0].size)
97+
# if args[0].value is not None and np.all(args[0].value > 0):
98+
# t.value = np.log(args[0].value)
99+
# else:
100+
# t.value = expr.point_in_domain()
101+
102+
# return t, [exp(t) == args[0]]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
Copyright 2025 CVXPY developers
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from cvxpy.expressions.variable import Variable
18+
19+
def quad_over_lin_canon(expr, args):
20+
assert(False)
21+
pass

0 commit comments

Comments
 (0)