Skip to content

Commit 3c6194e

Browse files
committed
feat: DBCP convolve atom
1 parent fc1309f commit 3c6194e

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

src/dbcp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from .atoms import convolve
12
from .problem import BiconvexProblem, BiconvexRelaxProblem
23

3-
__all__ = ["BiconvexProblem", "BiconvexRelaxProblem", "__version__"]
4+
__all__ = ["BiconvexProblem", "BiconvexRelaxProblem", "convolve", "__version__"]
45

56
try:
67
from importlib.metadata import version, PackageNotFoundError

src/dbcp/atoms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import cvxpy as cp
2+
3+
4+
def convolve(x, y):
5+
"""Discrete convolution of two 1-D cvxpy expressions.
6+
7+
Suppose :math:`x` and :math:`y` are 1-D cvxpy expressions of lengths
8+
:math:`m` and :math:`n`, respectively.
9+
This function returns a cvxpy expression :math:`c` of length :math:`m + n - 1`, where
10+
11+
.. math::
12+
13+
c_k = \\sum_{i + j = k} x_i y_j,\\quad k = 1, \\ldots, m + n - 1.
14+
15+
Matches numpy.convolve for 1-D arrays.
16+
17+
This function extends cvxpy.convolve atom to support the convolution
18+
operation between two cvxpy expressions.
19+
20+
Parameters
21+
----------
22+
x : cp.Expression
23+
A 1-D cvxpy expression.
24+
y : cp.Expression
25+
A 1-D cvxpy expression.
26+
27+
Returns
28+
-------
29+
cp.Expression
30+
The convolution of x and y.
31+
"""
32+
if x.ndim != 1 or y.ndim != 1:
33+
raise ValueError("Both inputs must be 1-D cvxpy expressions.")
34+
35+
c = [0] * (x.shape[0] + y.shape[0] - 1)
36+
for i, a in enumerate(y):
37+
for j, b in enumerate(x):
38+
c[i + j] += a * b
39+
return cp.hstack(c)

tests/test_atoms.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
import cvxpy as cp
3+
import numpy as np
4+
import dbcp
5+
6+
np.random.seed(10015)
7+
8+
9+
def test_convolve():
10+
# Test the convolution atom to match np.convolve
11+
x0 = np.random.randn(10)
12+
y0 = np.random.randn(20)
13+
c0 = np.convolve(x0, y0)
14+
15+
x = cp.Variable(10)
16+
y = cp.Variable(20)
17+
c = dbcp.convolve(x, y)
18+
19+
x.value = x0
20+
y.value = y0
21+
assert np.allclose(c.value, c0)
22+
23+
24+
def test_convolve_invalid_input():
25+
x = cp.Variable((3, 4))
26+
y = cp.Variable(5)
27+
with pytest.raises(ValueError):
28+
dbcp.convolve(x, y)
29+
30+
x = cp.Variable(6)
31+
y = cp.Variable((2, 3))
32+
with pytest.raises(ValueError):
33+
dbcp.convolve(x, y)

0 commit comments

Comments
 (0)