import pytensor
from pytensor.compile.ops import ViewOp
from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
from pytensor.graph.basic import Apply, NominalVariable, Variable
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.op import get_test_values
@@ -703,15 +703,15 @@ def grad(
            grad_dict[var] = g_var

    def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
        if disconnected_inputs == "ignore":
-            pass
+            return
        elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
            warnings.warn(message, stacklevel=2)
        elif disconnected_inputs == "raise":
            message = utils.get_variable_trace_string(var)
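For reference, the three `disconnected_inputs` modes behave as follows; a minimal sketch (`DisconnectedInputError` is the error pytensor raises in the `"raise"` branch):

```python
import pytensor.tensor as pt
from pytensor import grad

x = pt.scalar("x")
y = pt.scalar("y")  # y plays no role in the cost below
cost = x**2

gx, gy = grad(cost, [x, y], disconnected_inputs="ignore")  # gy is a zero gradient
gx, gy = grad(cost, [x, y], disconnected_inputs="warn")    # same, plus a warning
# grad(cost, [x, y], disconnected_inputs="raise")          # raises DisconnectedInputError
```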
@@ -2021,13 +2021,19 @@ def __str__(self):
Exception args: {args_msg}"""


-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize=False,
+):
    """
    Compute the full Jacobian, row by row.

    Parameters
    ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression : :class:`~pytensor.graph.basic.Variable`
        Values that we are differentiating (that we want the Jacobian of)
    wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
        Term[s] with respect to which we compute the Jacobian
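A minimal usage sketch of the new flag (assuming an elementwise cost, so the Jacobian is diagonal; both paths should evaluate to the same result):

```python
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
f = x**2  # elementwise, so the Jacobian is diag(2 * x)

J_scan = jacobian(f, x)                 # default: built row by row with scan
J_vec = jacobian(f, x, vectorize=True)  # single vectorized graph, no scan

# J_vec.eval({x: [1.0, 2.0]}) -> [[2., 0.], [0., 4.]]
```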
@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
    output, then a zero variable is returned. The return value is
    of same type as `wrt`: a list/tuple or TensorVariable in all cases.
    """
+    from pytensor.tensor import broadcast_to, eye

    if not isinstance(expression, Variable):
        raise TypeError("jacobian expects a Variable as `expression`")

-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
    using_list = isinstance(wrt, list)
    using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }

    if isinstance(wrt, list | tuple):
        wrt = list(wrt)
    else:
        wrt = [wrt]

    if all(expression.type.broadcastable):
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression.squeeze(),
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
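+        # Build the graph for a single Jacobian row as a vector-Jacobian product
+        # (Lop) against a symbolic row tangent, then substitute the identity
+        # matrix for that tangent via vectorize_graph to obtain all rows at once.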
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+
+        n_rows = expression_flat.size
+        jacobian_matrices = vectorize_graph(
+            jacobian_single_rows,
+            replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)},
        )
+        if disconnected_inputs != "raise":
+            # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian
+            # We have to broadcast the zeros explicitly here
+            for i, (jacobian_single_row, jacobian_matrix) in enumerate(
+                zip(jacobian_single_rows, jacobian_matrices, strict=True)
+            ):
+                if jacobian_single_row.ndim == jacobian_matrix.ndim:
+                    jacobian_matrices[i] = broadcast_to(
+                        jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape)
+                    )

-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+    else:
+
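+        # Fallback: scan over the raveled expression, differentiating one scalar
+        # element per iteration to build the Jacobian row by row.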
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
+            )
20952116 )
2096- rvals .append (rval )
2097- return rvals
2098-
2099- # Computing the gradients does not affect the random seeds on any random
2100- # generator used n expression (because during computing gradients we are
2101- # just backtracking over old values. (rp Jan 2012 - if anyone has a
2102- # counter example please show me)
2103- jacobs , updates = pytensor .scan (
2104- inner_function ,
2105- sequences = pytensor .tensor .arange (expression .shape [0 ]),
2106- non_sequences = [expression , * wrt ],
2107- )
2108- assert not updates , "Scan has returned a list of updates; this should not happen."
2109- return as_list_or_tuple (using_list , using_tuple , jacobs )
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)
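With the 1-d restriction removed, `expression` may now have any rank; the final reshape gives the result the shape `(*expression.shape, *wrt.shape)`. A small sketch of the new behavior:

```python
import pytensor.tensor as pt
from pytensor.gradient import jacobian

X = pt.matrix("X")                  # e.g. shape (2, 3) at runtime
Y = pt.exp(X)                       # elementwise, same shape as X
J = jacobian(Y, X, vectorize=True)  # shape (*Y.shape, *X.shape) = (2, 3, 2, 3)
```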


def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):