@@ -109,58 +109,54 @@ def is_nsd(self) -> bool:
         return True
 
     def _grad(self, values) -> List[Any]:
-        """Gives the (sub/super) gradient of the atom w.r.t. each argument.
+        """Computes the gradient of the affine atom w.r.t. each argument.
 
-        Matrix expressions are vectorized, so the gradient is a matrix.
+        For affine atoms, the gradient is constant and independent of argument values.
+        We compute it by constructing the canonical matrix representation and extracting
+        the linear coefficients.
 
         Args:
-            values: A list of numeric values for the arguments.
+            values: Argument values (unused for affine atoms).
 
         Returns:
-            A list of SciPy CSC sparse matrices or None.
+            List of gradient matrices, one for each argument.
         """
-        # TODO should be a simple function in cvxcore for this.
-        # Make a fake lin op tree for the function.
+        # Create fake variables for each non-constant argument to build the linear system
        fake_args = []
        var_offsets = {}
-        offset = 0
+        var_length = 0
+
        for idx, arg in enumerate(self.args):
            if arg.is_constant():
-                fake_args += [Constant(arg.value).canonical_form[0]]
+                fake_args.append(Constant(arg.value).canonical_form[0])
            else:
-                fake_args += [lu.create_var(arg.shape, idx)]
-                var_offsets[idx] = offset
-                offset += arg.size
-        var_length = offset
-        fake_expr, _ = self.graph_implementation(fake_args, self.shape,
-                                                 self.get_data())
-        param_to_size = {lo.CONSTANT_ID: 1}
-        param_to_col = {lo.CONSTANT_ID: 0}
-        # Get the matrix representation of the function.
+                fake_args.append(lu.create_var(arg.shape, idx))
+                var_offsets[idx] = var_length
+                var_length += arg.size
+
+        # Get the canonical matrix representation: f(x) = Ax + b
+        fake_expr, _ = self.graph_implementation(fake_args, self.shape, self.get_data())
        canon_mat = canonInterface.get_problem_matrix(
-            [fake_expr],
-            var_length,
-            var_offsets,
-            param_to_size,
-            param_to_col,
-            self.size,
+            [fake_expr], var_length, var_offsets,
+            {lo.CONSTANT_ID: 1}, {lo.CONSTANT_ID: 0}, self.size
        )
-        # HACK TODO TODO convert tensors back to vectors.
-        # COO = (V[lo.CONSTANT_ID][0], (J[lo.CONSTANT_ID][0], I[lo.CONSTANT_ID][0]))
-        shape = (var_length + 1, self.size)
-        stacked_grad = canon_mat.reshape(shape).tocsc()[:-1, :]
-        # Break up into per argument matrices.
+
+        # Extract gradient matrix A (exclude constant offset b)
+        grad_matrix = canon_mat.reshape((var_length + 1, self.size)).tocsc()[:-1, :]
+
+        # Split gradients by argument
        grad_list = []
-        start = 0
+        var_start = 0
        for arg in self.args:
            if arg.is_constant():
-                grad_shape = (arg.size, shape[1])
-                if grad_shape == (1, 1):
-                    grad_list += [0]
-                else:
-                    grad_list += [sp.coo_matrix(grad_shape, dtype='float64')]
+                # Zero gradient for constants
+                grad_shape = (arg.size, self.size)
+                grad_list.append(0 if grad_shape == (1, 1) else
+                                 sp.coo_matrix(grad_shape, dtype='float64'))
            else:
-                stop = start + arg.size
-                grad_list += [stacked_grad[start:stop, :]]
-                start = stop
+                # Extract gradient block for this variable
+                var_end = var_start + arg.size
+                grad_list.append(grad_matrix[var_start:var_end, :])
+                var_start = var_end
+
        return grad_list
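
A quick way to see the value-independence the new docstring describes is through cvxpy's public Expression.grad property, which exercises _grad internally for affine atoms. The sketch below is illustrative only (not part of the commit); the expression and values are made up for the example, and it assumes a standard cvxpy/NumPy/SciPy install.

import cvxpy as cp
import numpy as np

x = cp.Variable(3)
expr = cp.sum(2 * x) + 5               # affine: f(x) = [2, 2, 2] @ x + 5

# Gradients are only reported once variable values are set.
x.value = np.array([1.0, -1.0, 0.5])
g1 = expr.grad[x]                      # SciPy sparse matrix, shape (3, 1)

x.value = np.array([10.0, 20.0, 30.0])
g2 = expr.grad[x]

print(g1.toarray().ravel())            # [2. 2. 2.]
print(g2.toarray().ravel())            # [2. 2. 2.]  (same: gradient is constant)

Because the coefficients come straight from the canonical matrix A, the returned gradient blocks do not depend on the current variable values.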