@@ -128,6 +128,32 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
     ]


+def _get_parameter_grads_from_vector(
+    grad_wrt_args_vector: Variable,
+    x_star: Variable,
+    args: Sequence[Variable],
+    output_grad: Variable,
+):
+    """
+    Given a single concatenated vector of gradients of the objective function with respect to the
+    raveled optimization parameters, return each parameter's contribution to the gradient of the
+    total loss, reshaped to that parameter's original (unraveled) shape.
+    """
+    cursor = 0
+    grad_wrt_args = []
+
+    for arg in args:
+        arg_shape = arg.shape
+        arg_size = arg_shape.prod()
+        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+            (*x_star.shape, *arg_shape)
+        )
+
+        grad_wrt_args.append(dot(output_grad, arg_grad))
+        cursor += arg_size
+
+    return grad_wrt_args
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

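For reference, a rough standalone NumPy sketch of the slicing the new helper performs (shapes and names below are made up for illustration; the real helper operates on symbolic PyTensor variables and uses `dot` rather than `tensordot`):

# Illustrative only: mimic _get_parameter_grads_from_vector with concrete arrays.
import numpy as np

x_star = np.zeros(3)                                  # pretend optimum, shape (3,)
args = [np.zeros((2, 2)), np.zeros(5)]                # two parameters with sizes 4 and 5
grad_wrt_args_vector = np.arange(27.0).reshape(3, 9)  # (x_star.size, total parameter size)
output_grad = np.ones(3)                              # upstream gradient w.r.t. x_star

cursor, grad_wrt_args = 0, []
for arg in args:
    size = arg.size
    # slice out this parameter's block of columns and restore the parameter's shape
    block = grad_wrt_args_vector[:, cursor : cursor + size].reshape(*x_star.shape, *arg.shape)
    # contract with the upstream gradient over the x_star dimension
    grad_wrt_args.append(np.tensordot(output_grad, block, axes=1))
    cursor += size

print([g.shape for g in grad_wrt_args])               # [(2, 2), (5,)]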
@@ -348,34 +374,25 @@ def L_op(self, inputs, outputs, output_grads):

         implicit_f = grad(inner_fx, inner_x)

-        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
+        df_dx, *df_dtheta_columns = jacobian(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )

         df_dtheta = concatenate(
-            [
-                atleast_2d(x, left=False)
-                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
-            ],
+            [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
             axis=-1,
         )

         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))

-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
+        df_dx_star, df_dtheta_star = graph_replace(
+            [atleast_2d(df_dx), df_dtheta], replace=replace
+        )

         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )

         return [zeros_like(x), *grad_wrt_args]

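The `solve(-df_dx_star, df_dtheta_star)` line above is the implicit function theorem: because the inner gradient f(x*(theta), theta) vanishes at the optimum, dx*/dtheta = -(df/dx)^(-1) df/dtheta. A minimal numerical sanity check on an assumed toy quadratic (everything below is illustrative, not part of the diff):

# Minimize 0.5 * x @ A @ x - b @ x; the optimum is x* = inv(A) @ b, so d x*/d b = inv(A).
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
A = A @ A.T + 3 * np.eye(3)                 # symmetric positive definite
b = rng.normal(size=3)

# inner gradient f(x, b) = A @ x - b vanishes at the optimum
df_dx = A                                   # df/dx
df_db = -np.eye(3)                          # df/db
dxstar_db = np.linalg.solve(-df_dx, df_db)  # -(df/dx)^{-1} df/db, same form as above

np.testing.assert_allclose(dxstar_db, np.linalg.inv(A))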
@@ -504,19 +521,9 @@ def L_op(
         df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)

         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )

         return [zeros_like(x), *grad_wrt_args]

@@ -529,11 +536,7 @@ def root(
 ):
     """Find roots of a system of equations using scipy.optimize.root."""

-    args = [
-        arg
-        for arg in truncated_graph_inputs([equations], [variables])
-        if (arg is not variables and not isinstance(arg, Constant))
-    ]
+    args = _find_optimization_parameters(equations, variables)

     root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)
