1414limitations under the License.
1515"""
1616
17- from typing import List , Optional , Tuple
17+ from typing import Tuple
1818
1919import numpy as np
2020import scipy .sparse as sp
2323from cvxpy .atoms .atom import Atom
2424
2525
26+ def normalize_axis (
27+ axis : int | tuple [int , ...], ndim : int , reduce_all_to_none : bool = True
28+ ) -> None | int | tuple [int , ...]:
29+ """Normalize an axis argument to a canonical form.
30+
31+ - Negative indices become positive.
32+ - Single-element tuples become an int.
33+ - If all axes are listed and *reduce_all_to_none* is True, returns None.
34+ """
35+ axes = normalize_axis_tuple (axis , ndim )
36+ if reduce_all_to_none and len (axes ) == ndim :
37+ return None
38+ elif len (axes ) == 1 :
39+ return axes [0 ]
40+ else :
41+ return axes
42+
43+
2644class AxisAtom (Atom ):
2745 """
2846 An abstract base class for atoms that can be applied along an axis.
2947 """
3048
31- def __init__ (self , expr , axis : Optional [int ] = None , keepdims : bool = False ) -> None :
49+ # Whether reducing over all axes is equivalent to axis=None.
50+ # True for reduction atoms (sum, max, min, etc.).
51+ # False for cumulative atoms (cumsum, cummax, cumprod) that preserve shape.
52+ _reduce_all_axes_to_none = True
53+
54+ def __init__ (
55+ self , expr , axis : None | int | tuple [int , ...] = None , keepdims : bool = False
56+ ) -> None :
3257 self .axis = axis
3358 self .keepdims = keepdims
3459 super (AxisAtom , self ).__init__ (expr )
60+ # Normalize axis after init so self.args is available.
61+ if self .axis is not None :
62+ ndim = len (self .args [0 ].shape )
63+ if ndim > 0 :
64+ self .axis = normalize_axis (
65+ self .axis , ndim , self ._reduce_all_axes_to_none
66+ )
3567
3668 def shape_from_args (self ) -> Tuple [int , ...]:
3769 """
@@ -75,12 +107,14 @@ def validate_arguments(self) -> None:
75107 _ = normalize_axis_tuple (axes , dim )
76108 super (AxisAtom , self ).validate_arguments ()
77109
78- def _axis_grad (self , values ) -> Optional [ List [ sp .csc_array ]] :
110+ def _axis_grad (self , values ) -> list [ sp .csc_array ] | None :
79111 """
80112 Gives the (sub/super)gradient of the atom w.r.t. each argument.
81113
82114 Matrix expressions are vectorized, so the gradient is a matrix.
83- Takes axis into account.
115+ Takes axis into account. Works for any number of dimensions.
116+
117+ CVXPY convention: grad[i, j] = d(output_flat_F[j]) / d(input_flat_F[i])
84118
85119 Args:
86120 values: A list of numeric values for the arguments.
@@ -93,33 +127,74 @@ def _axis_grad(self, values) -> Optional[List[sp.csc_array]]:
93127 D = self ._column_grad (value )
94128 if D is not None :
95129 D = sp .csc_array (D )
130+ return [D ]
131+
132+ input_shape = self .args [0 ].shape
133+ ndim = len (input_shape )
134+
135+ # Normalize axis to tuple
136+ axis = self .axis
137+ axes = (axis ,) if isinstance (axis , int ) else tuple (axis )
138+ keep = [i for i in range (ndim ) if i not in axes ]
139+
140+ reduce_dims = [input_shape [a ] for a in axes ]
141+ reduce_size = int (np .prod (reduce_dims ))
142+ output_shape = tuple (input_shape [i ] for i in keep )
143+ input_size = int (np .prod (input_shape ))
144+ output_size = max (1 , int (np .prod (output_shape )))
145+
146+ # F-order strides: stride[k] = prod(input_shape[:k])
147+ f_strides = np .ones (ndim , dtype = int )
148+ for k in range (1 , ndim ):
149+ f_strides [k ] = f_strides [k - 1 ] * input_shape [k - 1 ]
150+
151+ # Flat input in F-order
152+ flat_input = values [0 ].ravel (order = 'F' )
153+
154+ # All output multi-indices in F-order
155+ if len (output_shape ) == 0 :
156+ out_multis = np .zeros ((0 , 1 ), dtype = int )
96157 else :
97- m , n = self .args [0 ].shape
98- if self .axis == 0 : # function apply to each column
99- D = sp .csc_array ((m * n , n ), dtype = float )
100- for i in range (n ):
101- value = values [0 ][:, i ]
102- d = self ._column_grad (value ).T
103- if d is None :
104- return [None ]
105- else :
106- d = np .array (d ).flatten ()
107- row = np .linspace (i * m , i * m + m - 1 , m ) # [i*m, i*m+1, ..., i*m+m-1]
108- col = np .ones ((m ))* i
109- D = D + sp .csc_array ((d , (row , col )),
110- shape = (m * n , n )) # d must be 1-D
111- else : # function apply to each row
112- values = np .transpose (values [0 ])
113- D = sp .csc_array ((m * n , m ), dtype = float )
114- for i in range (m ):
115- value = values [:, i ]
116- d = self ._column_grad (value ).T
117- if d is None :
118- return [None ]
119- row = np .linspace (i , i + (n - 1 )* m , n ) # [0+i, m+i, ..., m(n-1)+i]
120- col = np .ones ((n ))* i
121- D = D + sp .csc_array ((np .array (d )[0 ], (row , col )),
122- shape = (m * n , m )) # d must be 1-D
158+ out_multis = np .array (
159+ np .unravel_index (np .arange (output_size ), output_shape , order = 'F' )
160+ ) # shape: (len(keep), output_size)
161+
162+ # All reduce-axis multi-indices
163+ reduce_multis = np .array (
164+ np .unravel_index (np .arange (reduce_size ), reduce_dims )
165+ ).T # shape: (reduce_size, len(axes))
166+
167+ all_rows = []
168+ all_cols = []
169+ all_data = []
170+
171+ for j in range (output_size ):
172+ om = out_multis [:, j ]
173+
174+ # Build input multi-indices: fix keep axes, vary reduce axes
175+ in_multis = np .zeros ((reduce_size , ndim ), dtype = int )
176+ for idx , k in enumerate (keep ):
177+ in_multis [:, k ] = om [idx ]
178+ for idx , a in enumerate (axes ):
179+ in_multis [:, a ] = reduce_multis [:, idx ]
180+
181+ # Compute flat F-order indices for this fiber
182+ fiber_indices = in_multis @ f_strides
183+ fiber_values = flat_input [fiber_indices ]
184+
185+ d = self ._column_grad (fiber_values .reshape (- 1 , 1 ))
186+ if d is None :
187+ return [None ]
188+ d = np .asarray (d ).flatten ()
189+
190+ all_rows .append (fiber_indices )
191+ all_cols .append (np .full (reduce_size , j , dtype = int ))
192+ all_data .append (d )
193+
194+ rows = np .concatenate (all_rows )
195+ cols = np .concatenate (all_cols )
196+ data = np .concatenate (all_data )
197+ D = sp .csc_array ((data , (rows , cols )), shape = (input_size , output_size ))
123198 return [D ]
124199
125200 def _column_grad (self , value ):
0 commit comments