@@ -6,8 +6,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, grad_not_implemented, jacobian
-from pytensor.graph import rewrite_graph
+from pytensor.gradient import DisconnectedType, grad, grad_not_implemented, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
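The only functional change in this hunk is the new DisconnectedType import. As a rough standalone sketch (not part of the diff), this is what that type is used for in the hunks below: calling an instance of the type produces a variable that marks a gradient as structurally disconnected from an input.

    from pytensor.gradient import DisconnectedType

    # Calling the type instance creates a Variable of DisconnectedType; returning it from a
    # gradient method signals that the output does not depend on the corresponding input.
    disconnected_grad = DisconnectedType()()
    assert isinstance(disconnected_grad.type, DisconnectedType)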
@@ -255,16 +254,31 @@ def scalar_implict_optimization_grads(
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
 ) -> list[TensorVariable | ScalarVariable]:
+    inner_args_to_diff = []
+    outer_args_to_diff = []
+    for inner_arg, outer_arg in zip(inner_args, args):
+        if inner_arg.type.dtype.startswith("float"):
+            inner_args_to_diff.append(inner_arg)
+            outer_args_to_diff.append(outer_arg)
+
+    if len(args) > 0 and not inner_args_to_diff:
+        # No differentiable arguments, return disconnected gradients
+        return [DisconnectedType()() for _ in args]
+
     df_dx, *df_dthetas = grad(
-        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+        inner_fx, [inner_x, *inner_args_to_diff], disconnected_inputs="ignore"
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
 
+    arg_to_grad = dict(zip(outer_args_to_diff, df_dthetas_stars))
+
     grad_wrt_args = [
-        (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in df_dthetas_stars
+        (-arg_to_grad[arg] / df_dx_star) * output_grad
+        if arg in arg_to_grad
+        else DisconnectedType()()
+        for arg in args
     ]
 
     return grad_wrt_args
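For reference, the rewritten comprehension still applies the scalar implicit-function-theorem rule: if f(x*, theta) = 0 defines x* implicitly, then dx*/dtheta = -(df/dtheta) / (df/dx) evaluated at x*. A minimal NumPy sketch with a hypothetical objective f(x, theta) = x**2 - theta (names here are illustrative, not from the diff):

    import numpy as np

    theta = 2.0
    x_star = np.sqrt(theta)      # root of f(x, theta) = x**2 - theta
    df_dx_star = 2.0 * x_star    # df/dx at the root
    df_dtheta_star = -1.0        # df/dtheta at the root
    output_grad = 1.0            # upstream gradient dL/dx*

    grad_wrt_theta = (-df_dtheta_star / df_dx_star) * output_grad
    # Matches the analytic derivative of x* = sqrt(theta)
    assert np.isclose(grad_wrt_theta, 1.0 / (2.0 * np.sqrt(theta)))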
@@ -317,10 +331,26 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
+
+    # There might be non-differentiable arguments along the compute path from the objective to the inputs. Notably,
+    # integers often arise due to Shape ops called by pack/unpack. These will be given DisconnectedType gradients.
+    # First, they are filtered out before calling jacobian.
+    inner_args_to_diff = []
+    outer_args_to_diff = []
+    for inner_arg, outer_arg in zip(inner_args, args):
+        if inner_arg.type.dtype.startswith("float"):
+            inner_args_to_diff.append(inner_arg)
+            outer_args_to_diff.append(outer_arg)
+
+    if len(args) > 0 and not inner_args_to_diff:
+        # No differentiable arguments, return disconnected gradients
+        return [DisconnectedType()() for _ in args]
+
+    # Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op.
     packed_inner_args, packed_arg_shapes, implicit_f = (
         _maybe_pack_input_variables_and_rewrite_objective(
             implicit_f,
-            inner_args,
+            inner_args_to_diff,
         )
     )
 
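The same dtype filter appears in both functions. A standalone sketch of the criterion it applies, assuming one float parameter and one integer input such as a shape introduced by pack/unpack (variable names are illustrative, not from the diff):

    import pytensor.tensor as pt

    inner_args = [pt.dvector("theta"), pt.lscalar("n")]
    args = [pt.dvector("theta_outer"), pt.lscalar("n_outer")]

    pairs_to_diff = [
        (inner, outer)
        for inner, outer in zip(inner_args, args)
        if inner.type.dtype.startswith("float")  # only float dtypes are differentiable
    ]
    assert [inner.name for inner, _ in pairs_to_diff] == ["theta"]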
@@ -331,9 +361,11 @@ def implict_optimization_grads(
         vectorize=use_vectorized_jac,
     )
 
+    # Replace inner inputs (abstract dummies) with outer inputs (the actual user-provided symbols)
+    # at the solution point. From here on, the inner values should not be referenced.
     inner_to_outer_map = dict(zip(fgraph.inputs, (x_star, *args)))
-
     df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], inner_to_outer_map)
+
    grad_wrt_args_packed = solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star))
 
     if packed_arg_shapes is not None:
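The solve call above is the vector form of the implicit-function theorem: with F(x*, theta) = 0, the Jacobian dx*/dtheta satisfies (dF/dx) @ dx*/dtheta = -(dF/dtheta) at the solution. A small NumPy sketch with a hypothetical linear objective F(x, theta) = A @ x - theta, so that x* = A^{-1} theta and dx*/dtheta = A^{-1}:

    import numpy as np

    A = np.array([[3.0, 1.0], [1.0, 2.0]])
    df_dx_star = A               # Jacobian of F w.r.t. x at the solution
    df_dtheta_star = -np.eye(2)  # Jacobian of F w.r.t. theta at the solution

    # Same contraction as solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star)) above.
    grad_wrt_args_packed = np.linalg.solve(-df_dx_star, df_dtheta_star)
    assert np.allclose(grad_wrt_args_packed, np.linalg.inv(A))  # dx*/dtheta = A^{-1}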
@@ -351,16 +383,23 @@ def implict_optimization_grads(
         grad_wrt_args_packed = grad_wrt_args_packed.squeeze(axis=0)
         grad_wrt_args = [grad_wrt_args_packed]
 
-    final_grads = [
-        tensordot(output_grad, arg_grad, [[0], [0]])
-        if arg_grad.ndim > 0 and output_grad.ndim > 0
-        else arg_grad * output_grad
-        for arg_grad in grad_wrt_args
-    ]
-    final_grads = [
-        scalar_from_tensor(g) if isinstance(arg.type, ScalarType) else g
-        for arg, g in zip(args, final_grads)
-    ]
+    arg_to_grad = dict(zip(outer_args_to_diff, grad_wrt_args))
+
+    final_grads = []
+    for arg in args:
+        arg_grad = arg_to_grad.get(arg, None)
+
+        if arg_grad is None:
+            final_grads.append(DisconnectedType()())
+            continue
+
+        if arg_grad.ndim > 0 and output_grad.ndim > 0:
+            g = tensordot(output_grad, arg_grad, [[0], [0]])
+        else:
+            g = arg_grad * output_grad
+        if isinstance(arg.type, ScalarType):
+            g = scalar_from_tensor(g)
+        final_grads.append(g)
 
     return final_grads
 
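The branch on ndim above reduces to the usual vector-Jacobian product when both the upstream gradient and the argument gradient are non-scalar. A NumPy sketch of that contraction for a vector-valued solution and a vector argument (shapes chosen arbitrarily for illustration):

    import numpy as np

    output_grad = np.array([0.5, -1.0])            # dL/dx*, shape (n,)
    arg_grad = np.array([[1.0, 0.0, 2.0],
                         [0.0, 3.0, 1.0]])         # dx*/dtheta, shape (n, k)

    # tensordot(output_grad, arg_grad, [[0], [0]]) contracts over the solution axis,
    # yielding dL/dtheta with shape (k,).
    final_grad = np.tensordot(output_grad, arg_grad, axes=[[0], [0]])
    assert final_grad.shape == (3,)
    assert np.allclose(final_grad, output_grad @ arg_grad)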
@@ -640,7 +679,7 @@ def _maybe_pack_input_variables_and_rewrite_objective(
             for xi, ui in zip(x, unpacked_output)
         },
     )
-    objective = rewrite_graph(objective, include=("ShapeOpt", "canonicalize"))
+
     return packed_input, packed_shapes, objective
 
 