Merged
7 changes: 7 additions & 0 deletions pytensor/tensor/basic.py
@@ -664,6 +664,11 @@ def c_code_cache_version(self):
tensor_from_scalar = TensorFromScalar()


@_vectorize_node.register(TensorFromScalar)
def vectorize_tensor_from_scalar(op, node, batch_x):
return identity(batch_x).owner


class ScalarFromTensor(COp):
__props__ = ()

@@ -2046,6 +2051,7 @@ def register_transfer(fn):
"""Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = Elemwise(ps.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter())
identity = tensor_copy


class Default(Op):
@@ -4603,6 +4609,7 @@ def ix_(*args):
"matrix_transpose",
"default",
"tensor_copy",
"identity",
"transfer",
"alloc",
"identity_like",
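The basic.py change does two things: it exposes the existing `tensor_copy` op under the new alias `identity` (also added to `__all__`), and it registers a vectorization rule so that batching a `TensorFromScalar` node collapses to a plain identity on the batched input. A minimal sketch of how that rule might be exercised; the variable names, the float64 dtype, and the direct use of `vectorize_graph` are illustrative assumptions, not part of the PR:

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.basic import identity, tensor_from_scalar

s = ps.float64("s")            # scalar-typed (0-d) variable
out = tensor_from_scalar(s)    # wrap it as a 0-d tensor

# Replace the scalar input with a batched tensor; under the new rule the
# TensorFromScalar node vectorizes to identity(batch).
batch = pt.vector("batch", dtype="float64")
vec_out = vectorize_graph(out, replace={s: batch})

# The alias itself is a plain re-export: identity is tensor_copy.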
60 changes: 46 additions & 14 deletions pytensor/tensor/optimize.py
@@ -7,7 +7,7 @@

import pytensor.scalar as ps
from pytensor.compile.function import function
from pytensor.gradient import grad, hessian, jacobian
from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ def __init__(
jac: bool = True,
hess: bool = False,
hessp: bool = False,
use_vectorized_jac: bool = False,
optimizer_kwargs: dict | None = None,
):
if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@
)

self.fgraph = FunctionGraph([x, *args], [objective])
self.use_vectorized_jac = use_vectorized_jac

if jac:
grad_wrt_x = cast(
@@ -505,7 +507,12 @@

if hess:
hess_wrt_x = cast(
Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
Variable,
jacobian(
self.fgraph.outputs[-1],
self.fgraph.inputs[0],
vectorize=use_vectorized_jac,
),
)
self.fgraph.add_output(hess_wrt_x)
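For reference, this builds the Hessian as the Jacobian of the gradient output appended just above (when `jac=True`) rather than via the dedicated `hessian` helper. The two constructions agree because

    H(f)_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j} = \frac{\partial (\nabla f)_i}{\partial x_j},

and routing the computation through `jacobian` is what exposes the new `vectorize` switch.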

@@ -561,7 +568,10 @@ def L_op(self, inputs, outputs, output_grads):
implicit_f = grad(inner_fx, inner_x)

df_dx, *df_dtheta_columns = jacobian(
implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
implicit_f,
[inner_x, *inner_args],
disconnected_inputs="ignore",
vectorize=self.use_vectorized_jac,
)
grad_wrt_args = implict_optimization_grads(
df_dx=df_dx,
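For context, the gradient rule here follows the implicit function theorem: at the optimum x^*(\theta) the inner gradient vanishes, \nabla_x f(x^*(\theta), \theta) = 0, so differentiating that condition with respect to \theta gives

    \frac{\partial x^*}{\partial \theta} = -\left(\frac{\partial^2 f}{\partial x \, \partial x^\top}\right)^{-1} \frac{\partial^2 f}{\partial x \, \partial \theta}.

The `df_dx` and `df_dtheta` Jacobians computed above are the two second-derivative blocks; `implict_optimization_grads` is assumed to assemble them into this expression, and the new `vectorize` flag only changes how those Jacobians are built.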
@@ -581,6 +591,7 @@ def minimize(
method: str = "BFGS",
jac: bool = True,
hess: bool = False,
use_vectorized_jac: bool = False,
optimizer_kwargs: dict | None = None,
) -> tuple[TensorVariable, TensorVariable]:
"""
@@ -590,18 +601,21 @@
----------
objective : TensorVariable
The objective function to minimize. This should be a pytensor variable representing a scalar value.

x : TensorVariable
x: TensorVariable
The variable with respect to which the objective function is minimized. It must be an input to the
computational graph of `objective`.

method : str, optional
method: str, optional
The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.

jac : bool, optional
Whether to compute and use the gradient of teh objective function with respect to x for optimization.
jac: bool, optional
Whether to compute and use the gradient of the objective function with respect to x for optimization.
Default is True.

hess: bool, optional
Whether to compute and use the Hessian of the objective function with respect to x for optimization.
Default is False. Note that some methods require this, while others do not support it.
use_vectorized_jac: bool, optional
Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
but use more memory. Default is False.
optimizer_kwargs
Additional keyword arguments to pass to scipy.optimize.minimize
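A hedged usage sketch of the new flag; the quadratic objective, the variable names, and the choice of Newton-CG are placeholders rather than anything taken from the PR:

import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
a = pt.vector("a")
objective = ((x - a) ** 2).sum()   # scalar objective, differentiable in x

# Gradient and Hessian graphs are built with vectorized (vmap-style)
# jacobians instead of scan-based ones.
x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",
    jac=True,
    hess=True,
    use_vectorized_jac=True,
)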

@@ -624,6 +638,7 @@ def minimize(
method=method,
jac=jac,
hess=hess,
use_vectorized_jac=use_vectorized_jac,
optimizer_kwargs=optimizer_kwargs,
)

@@ -804,6 +819,7 @@ def __init__(
method: str = "hybr",
jac: bool = True,
optimizer_kwargs: dict | None = None,
use_vectorized_jac: bool = False,
):
if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
raise ValueError(
@@ -817,7 +833,11 @@
self.fgraph = FunctionGraph([variables, *args], [equations])

if jac:
jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
jac_wrt_x = jacobian(
self.fgraph.outputs[0],
self.fgraph.inputs[0],
vectorize=use_vectorized_jac,
)
self.fgraph.add_output(atleast_2d(jac_wrt_x))

self.jac = jac
@@ -897,8 +917,14 @@ def L_op(
inner_x, *inner_args = self.fgraph.inputs
inner_fx = self.fgraph.outputs[0]

df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
df_dx = (
jacobian(inner_fx, inner_x, vectorize=True)
if not self.jac
else self.fgraph.outputs[1]
)
df_dtheta_columns = jacobian(
inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
)

grad_wrt_args = implict_optimization_grads(
df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
variables: TensorVariable,
method: str = "hybr",
jac: bool = True,
use_vectorized_jac: bool = False,
optimizer_kwargs: dict | None = None,
) -> tuple[TensorVariable, TensorVariable]:
"""
@@ -935,6 +962,10 @@
jac : bool, optional
Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
Default is True. Most methods require this.
use_vectorized_jac: bool, optional
Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
Default is False.
optimizer_kwargs : dict, optional
Additional keyword arguments to pass to `scipy.optimize.root`.
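A corresponding sketch for `root`; the two-equation system and the names below are illustrative only:

import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.vector("v")
theta = pt.scalar("theta")

# Two equations in two unknowns, parameterized by theta.
equations = pt.stack([v[0] ** 2 + v[1] - theta, v[0] - v[1]])

solution, success = root(
    equations,
    v,
    method="hybr",
    jac=True,
    use_vectorized_jac=True,
)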

@@ -958,6 +989,7 @@
method=method,
jac=jac,
optimizer_kwargs=optimizer_kwargs,
use_vectorized_jac=use_vectorized_jac,
)

solution, success = cast(