Skip to content

Commit 3b5eb4e

Browse files
Add cp.stack atom (cvxpy#2956)
* Add cp.stack atom * rebase * address comments: remove erroneous import, add canonicalization test * Run tests against editable install * Enable editable install for OpenMP builds * Use editable install in backend workflow * Fix Windows install mode * Revert Windows-specific pytest args
1 parent 20ad799 commit 3b5eb4e

File tree

6 files changed

+229
-5
lines changed

6 files changed

+229
-5
lines changed

.github/workflows/test_backends.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@ jobs:
1818
- uses: actions/checkout@v5
1919
- name: Install cvxpy dependencies
2020
run: |
21-
pip install .
21+
pip install -e .
2222
pip install pytest hypothesis
2323
- name: Run tests for each non-default backend
2424
run : |
2525
export CVXPY_DEFAULT_CANON_BACKEND="SCIPY"
2626
python -c "from cvxpy.cvxcore.python.canonInterface import get_default_canon_backend; print(get_default_canon_backend())"
2727
python -c "from cvxpy.cvxcore.python.canonInterface import get_default_canon_backend; assert get_default_canon_backend() == 'SCIPY'"
2828
pytest
29-

continuous_integration/test_script.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ python -c "import numpy; print('numpy %s' % numpy.__version__)"
1616
python -c "import scipy; print('scipy %s' % scipy.__version__)"
1717

1818
if [ $USE_OPENMP == "True" ] && [ $RUNNER_OS == "Linux" ]; then
19-
CFLAGS="-fopenmp" LDFLAGS="-lgomp" uv pip install .
19+
CFLAGS="-fopenmp" LDFLAGS="-lgomp" uv pip install -e .
2020
export OMP_NUM_THREADS=4
2121
else
2222
uv pip list
23-
uv pip install .
23+
uv pip install -e .
2424
fi
2525

2626
python -c "import cvxpy; print(cvxpy.installed_solvers())"

cvxpy/atoms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from cvxpy.atoms.affine.reshape import deep_flatten, reshape
3535
from cvxpy.atoms.affine.squeeze import squeeze
3636
from cvxpy.atoms.affine.concatenate import concatenate
37+
from cvxpy.atoms.affine.stack import stack
3738
from cvxpy.atoms.affine.sum import sum
3839
from cvxpy.atoms.affine.trace import trace, Trace
3940
from cvxpy.atoms.affine.transpose import (transpose, permute_dims, swapaxes, moveaxis)

cvxpy/atoms/affine/stack.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
Copyright, the CVXPY authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
16+
NumPy-style stacking that inserts a new axis for CVXPY Expressions.
17+
18+
This mirrors the semantics of ``np.stack``, where they make sense, for symbolic
19+
expressions: every input must share an identical shape, and the result gains a
20+
new axis of length ``len(arrays)`` at the requested position. Options that are
21+
specific to NumPy's ndarray type system (for example ``out`` or ``dtype``) are
22+
intentionally unsupported.
23+
24+
Examples
25+
--------
26+
>>> import cvxpy as cp
27+
>>> a = cp.Parameter((3,))
28+
>>> b = cp.Parameter((3,))
29+
>>> y = cp.stack([a, b], axis=0)
30+
>>> y.shape
31+
(2, 3)
32+
>>> z = cp.stack([a, b], axis=-1)
33+
>>> z.shape
34+
(3, 2)
35+
"""
36+
37+
from typing import Iterable, Sequence
38+
39+
from cvxpy.atoms.affine.concatenate import concatenate
40+
from cvxpy.atoms.affine.reshape import reshape
41+
from cvxpy.expressions.expression import Expression
42+
43+
44+
def _as_expression(obj) -> Expression:
45+
"""Cast scalars/arrays to ``Constant`` while leaving Expressions intact."""
46+
return obj if isinstance(obj, Expression) else Expression.cast_to_const(obj)
47+
48+
49+
def stack(arrays: Sequence[object] | Iterable[object], axis: int = 0) -> Expression:
50+
"""Join a sequence of expressions along a new axis.
51+
52+
Parameters
53+
----------
54+
arrays
55+
Sequence of expressions (or array-likes) that all have the same shape.
56+
axis
57+
Index of the new axis in the result. Values in ``[-(ndim + 1), ndim``
58+
``+ 1)`` are accepted, following ``numpy.stack``.
59+
60+
Returns
61+
-------
62+
Expression
63+
Expression whose shape equals the common input shape with the new axis
64+
inserted at ``axis`` and length ``len(arrays)`` along that axis.
65+
66+
Raises
67+
------
68+
TypeError
69+
If ``axis`` is not an integer.
70+
ValueError
71+
If ``arrays`` is empty, shapes differ, or ``axis`` is out of bounds.
72+
"""
73+
xs = [_as_expression(arg) for arg in arrays]
74+
if not xs:
75+
raise ValueError("need at least one array to stack")
76+
77+
if not isinstance(axis, int):
78+
raise TypeError(f"axis must be an int; received {type(axis).__name__}")
79+
80+
shapes = {expr.shape for expr in xs}
81+
if len(shapes) != 1:
82+
raise ValueError(
83+
"all input arrays must have the same shape; got "
84+
f"{sorted(shapes)}"
85+
)
86+
87+
base_shape = xs[0].shape
88+
result_ndim = len(base_shape) + 1
89+
if not (-result_ndim <= axis < result_ndim):
90+
raise ValueError(
91+
f"axis {axis} is out of bounds for result ndim {result_ndim}"
92+
)
93+
94+
axis_index = axis if axis >= 0 else axis + result_ndim
95+
# Slice the shape so we can splice a singleton axis where ``axis`` points.
96+
prefix = base_shape[:axis_index]
97+
suffix = base_shape[axis_index:]
98+
# Reshape each argument to inject the new length-1 axis before concatenating.
99+
reshaped = [
100+
reshape(expr, prefix + (1,) + suffix, order='F')
101+
for expr in xs
102+
]
103+
104+
return concatenate(reshaped, axis=axis_index)
105+
106+
107+
__all__ = ["stack"]
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
Copyright, the CVXPY authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
16+
Unit tests for the NumPy-style ``cp.stack`` helper. Each test focuses on a
17+
single aspect of the API contract so failures clearly indicate which axis
18+
handling or validation rule regressed.
19+
"""
20+
21+
import numpy as np
22+
import pytest
23+
24+
import cvxpy as cp
25+
26+
27+
def test_stack_1d_axis0() -> None:
28+
"""Two 1-D Parameters stack along axis 0 to form a 2x4 array."""
29+
a = cp.Parameter((4,))
30+
b = cp.Parameter((4,))
31+
y = cp.stack([a, b], axis=0)
32+
assert y.shape == (2, 4)
33+
34+
35+
def test_stack_1d_axis_last_numeric_parity() -> None:
36+
"""Scalar arrays stacked along the last axis match NumPy numerically."""
37+
a = np.array([1.0, 2.0, 3.0])
38+
b = np.array([4.0, 5.0, 6.0])
39+
y = cp.stack([a, b], axis=-1)
40+
assert y.shape == (3, 2)
41+
expected = np.stack([a, b], axis=-1)
42+
assert np.allclose(y.value, expected)
43+
44+
45+
def test_stack_2d_various_axes() -> None:
46+
"""Validate axis normalization on 2-D operands for several positions."""
47+
a = cp.Parameter((3, 4))
48+
b = cp.Parameter((3, 4))
49+
assert cp.stack([a, b], axis=0).shape == (2, 3, 4)
50+
assert cp.stack([a, b], axis=1).shape == (3, 2, 4)
51+
assert cp.stack([a, b], axis=-1).shape == (3, 4, 2)
52+
53+
54+
def test_stack_scalar_inputs() -> None:
55+
"""Literal scalars auto-wrap into Constants and stack as a 1-D vector."""
56+
y = cp.stack([1, 2, 3], axis=0)
57+
assert y.shape == (3,)
58+
assert np.allclose(y.value, np.array([1, 2, 3]))
59+
60+
61+
def test_stack_shape_mismatch_raises() -> None:
62+
"""Inputs of different shapes trigger the NumPy-style ValueError."""
63+
a = cp.Parameter((3,))
64+
b = cp.Parameter((4,))
65+
with pytest.raises(ValueError):
66+
cp.stack([a, b], axis=0)
67+
68+
69+
def test_stack_empty_list_raises() -> None:
70+
"""An empty input sequence is rejected."""
71+
with pytest.raises(ValueError):
72+
cp.stack([], axis=0)
73+
74+
75+
def test_stack_axis_bounds_check() -> None:
76+
"""Axis validation mirrors NumPy bounds for the resulting ndim."""
77+
a = cp.Parameter((3,))
78+
# ndim=1 -> result_ndim=2; valid axes: -2, -1, 0, 1
79+
for good in (-2, -1, 0, 1):
80+
cp.stack([a, a], axis=good)
81+
for bad in (-3, 2, 3):
82+
with pytest.raises(ValueError):
83+
cp.stack([a, a], axis=bad)
84+
85+
86+
def test_stack_non_int_axis_raises() -> None:
87+
"""Non-integer axes raise a TypeError before shape checks run."""
88+
a = cp.Parameter((2,))
89+
with pytest.raises(TypeError):
90+
cp.stack([a, a], axis=0.5)
91+
92+
93+
def test_stack_variables_shape_only() -> None:
94+
"""Variables share shape metadata with the stacked expression."""
95+
x = cp.Variable((2, 3))
96+
y = cp.Variable((2, 3))
97+
z = cp.stack([x, y], axis=0)
98+
assert z.shape == (2, 2, 3)
99+
100+
101+
def test_stack_canonicalization_resolves_equalities() -> None:
102+
"""Canonicalization maps scalar variables onto the stacked vector."""
103+
x = cp.Variable()
104+
y = cp.Variable()
105+
z = cp.Variable(2)
106+
z_tilde = cp.stack([x, y])
107+
problem = cp.Problem(cp.Minimize(0), [z_tilde == z, x == 1, y == 2])
108+
problem.solve(solver=cp.SCS)
109+
assert problem.status == cp.OPTIMAL
110+
assert np.allclose(z.value, np.array([1.0, 2.0]))

doc/source/api_reference/cvxpy.atoms.affine.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,11 @@ vec_to_upper_tri
271271
vstack
272272
-----------------------------------
273273

274-
.. autofunction:: cvxpy.vstack
274+
.. autofunction:: cvxpy.vstack
275+
276+
.. _stack:
277+
278+
stack
279+
-----------------------------------
280+
281+
.. autofunction:: cvxpy.stack

0 commit comments

Comments
 (0)