@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -703,15 +703,15 @@ def grad(
         grad_dict[var] = g_var
 
     def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
         if disconnected_inputs == "ignore":
-            pass
+            return
         elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
             warnings.warn(message, stacklevel=2)
         elif disconnected_inputs == "raise":
             message = utils.get_variable_trace_string(var)
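The refactor above only defers building the warning message until the branch that actually uses it, so the observable behavior of each mode is unchanged. A minimal sketch of the three `disconnected_inputs` modes (the variables here are illustrative, not from this diff):

```python
import pytensor.tensor as pt
from pytensor.gradient import DisconnectedInputError, grad

x = pt.scalar("x")
y = pt.scalar("y")  # y does not appear in the cost below
cost = x**2

# "ignore": the disconnected input silently gets a zero gradient
gx, gy = grad(cost, [x, y], disconnected_inputs="ignore")

# "warn": same zero gradient, plus the warning message built above
gx, gy = grad(cost, [x, y], disconnected_inputs="warn")

# "raise" (the default): a DisconnectedInputError is raised
try:
    grad(cost, [x, y], disconnected_inputs="raise")
except DisconnectedInputError:
    pass
```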
@@ -2021,13 +2021,19 @@ def __str__(self):
 Exception args: {args_msg}"""
 
 
-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize=False,
+):
     """
     Compute the full Jacobian, row by row.
 
     Parameters
     ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression :class:`~pytensor.graph.basic.Variable`
         Values that we are differentiating (that we want the Jacobian of)
     wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
         Term[s] with respect to which we compute the Jacobian
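The loosened `expression` type goes hand in hand with the removal of the old 1-dimensional check in the hunk below: callers no longer need to `flatten()` a higher-rank expression themselves. A small sketch of what this permits, assuming the reshape logic added at the end of the function (the example expression is arbitrary):

```python
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
expr = pt.outer(x, x)  # matrix-valued expression, shape (n, n)

# Previously this raised ValueError unless expr was flattened first;
# now the result is reshaped to expression.shape + wrt.shape, i.e. (n, n, n).
J = jacobian(expr, x)
```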
@@ -2051,62 +2057,74 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
     output, then a zero variable is returned. The return value is
     of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    from pytensor.tensor.basic import eye
+    from pytensor.tensor.extra_ops import broadcast_to
 
     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")
 
-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }
 
     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]
 
     if all(expression.type.broadcastable):
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression.squeeze(),
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+
+        n_rows = expression_flat.size
+        jacobian_matrices = vectorize_graph(
+            jacobian_single_rows,
+            replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)},
         )
+        if disconnected_inputs != "raise":
+            # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian
+            # We have to broadcast the zeros explicitly here
+            for i, (jacobian_single_row, jacobian_matrix) in enumerate(
+                zip(jacobian_single_rows, jacobian_matrices, strict=True)
+            ):
+                if jacobian_single_row.ndim == jacobian_matrix.ndim:
+                    jacobian_matrices[i] = broadcast_to(
+                        jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape)
+                    )
 
-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+    else:
+
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
             )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)
 
 
 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
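In summary, the new `vectorize=True` path builds one symbolic Jacobian row as the L-operator product of a placeholder row tangent with the expression, then `vectorize_graph` swaps in `eye(n_rows)` for that tangent so every row is produced in a single vectorized graph, replacing the sequential `scan`. A hedged usage sketch (the expression is arbitrary, chosen only for illustration):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
expr = pt.exp(x) * x.sum()  # arbitrary differentiable expression

# Default path: one scan step per row of the Jacobian.
J_scan = jacobian(expr, x)

# New path: Lop on a single symbolic row, then vectorize_graph over
# eye(n_rows); no scan node appears in the resulting graph.
J_vec = jacobian(expr, x, vectorize=True)

f = pytensor.function([x], [J_scan, J_vec])
```

Both paths should return identical values; the vectorized graph trades the memory for an `n_rows`-sized identity matrix and broadcast intermediates against the per-row overhead of `scan`.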