Skip to content

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jan 1, 2026

Description

Allows optimize.minimize and optimize.root to be called with multiple inputs of arbitrary shapes, as in:

    x0, x1, x2 = pt.dvectors("x1", "x2", "d3")
    x3 = pt.dmatrix("x3")
    b0, b1, b2 = pt.dscalars("b0", "b1", "b2")
    b3 = pt.dvector("b3")

    y = pt.dvector("y")

    y_hat = x0 * b0 + x1 * b1 + x2 * b2 + x3 @ b3
    objective = ((y - y_hat) ** 2).sum()

    minimized_x, success = minimize(
        objective,
        [b0, b1, b2, b3],
        jac=True,
        hess=True,
        method="Newton-CG",
        use_vectorized_jac=True,
    )

Internally, pack and unpack are used to convert the problem to 1d and return results in the same shape as the inputs.

pack and unpack are also used in the gradients, to simplify the construction of the jacobian with respect to arguments to the optimization function. We should consider simply using pack/unpack in the jacobian function itself, and add an option to get back the unpacked form (what we currently give back -- the columns of the jacobian matrix) or the backed form (a single matrix).

Tests are failing because of a bug in scan, I'm going to have to beg @ricardoV94 to help me understand how to fix that.

This PR also adds L_op implementations for SplitDims and JoinDims. I found this was easier than constantly rewriting the graph to try to remove these ops. Their pullbacks are also SplitDims and JoinDims, so in the end the gradients will be rewritten into Reshape as well, so I don't see any harm.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jessegrabowski jessegrabowski added bug Something isn't working enhancement New feature or request feature request SciPy compatibility labels Jan 1, 2026
@jessegrabowski
Copy link
Member Author

The scan bug happens here, because shape information is being destroyed somewhere. If I comment out that check, all tests pass.

@jessegrabowski jessegrabowski force-pushed the optimize-use-pack branch 2 times, most recently from 2fd5ae0 to be39ef6 Compare January 1, 2026 04:12
@jessegrabowski
Copy link
Member Author

jessegrabowski commented Jan 1, 2026

Regarding the notebook error reported in #1586, the notebooks runs now but with rewrite warnings. The specific rewrite has to do with squeeze, but it is arising because of the use of vectorize_graph on the gradients of root.

We potentially ought to rewrite to scalar_minimize in that case, but scipy.minimize handles this case gracefully and we should too.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR enhances optimize.minimize and optimize.root to accept multiple input variables of arbitrary shapes, addressing issues #1550, #1465, and #1586. The implementation uses pack and unpack operations to handle multiple variables by flattening them into a single vector for scipy optimization, then reshaping results back to their original forms.

Key Changes

  • Added pack/unpack support to minimize and root for handling multiple variables of different shapes
  • Implemented L_op (gradient) methods for JoinDims and SplitDims ops to support autodiff through pack/unpack operations
  • Refactored implict_optimization_grads to use packed variables internally, simplifying jacobian construction

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.

File Description
pytensor/tensor/optimize.py Added _maybe_pack_input_variables_and_rewrite_objective function; updated minimize and root signatures to accept sequences; refactored gradient computation to use packed variables
pytensor/tensor/reshape.py Added L_op implementations for JoinDims and SplitDims; added connection_pattern to SplitDims; improved unpack to handle single-element lists without splitting
tests/tensor/test_optimize.py Added comprehensive tests for multiple minimands, MVN logp gradient regression, multiple root inputs, and vectorized root gradients
tests/tensor/test_reshape.py Added gradient verification tests for join_dims and split_dims; changed rewrite pass from "specialize" to "canonicalize"

@jessegrabowski jessegrabowski force-pushed the optimize-use-pack branch 2 times, most recently from 66cf591 to 3e8a3f3 Compare January 1, 2026 21:07
@ricardoV94
Copy link
Member

ricardoV94 commented Jan 2, 2026

I'm not sure about moving the complexity of handling multiple inputs to our op helpers. Is it that hard to ask users to use pack/unpack themselves. This way the PR is also harder to review as you're doing new feature and bugfix together, and the changes are by no means trivial.

In doing so, you're also moving quite away from the scipy API so it will be less obvious how these work

@jessegrabowski
Copy link
Member Author

I want this API as the front end. We're already far from the scipy API -- we don't take jac or hess functions, and we don't take args. The use-defined single vector case is still valid, so this is a strict upgrade with 100% backwards compat.

I was quite careful to keep the commits split by change, it's difficult to review commit by commit? I am willing to split them into separate PRs if you insist.

@ricardoV94
Copy link
Member

It's not hard to review commit by commit, but I have less trust that the bugfix commit fixed the bug, and not the API change logic for instance

@jessegrabowski
Copy link
Member Author

Sure I can split it out then. I can also address #1466 in a single bugfix PR then circle back to this.

@jessegrabowski
Copy link
Member Author

On further consideration, the other two bugs are directly addressed by this PR. I split the gradients out, since those are different.

Both #1550 #1586 are reporting the same bug. Root/minimize currently fail when computing gradients with respect args with ndim > 2. This PR will handle that natively by using pack/unpack. Specifically, this line assumes that the return from jacobian is always <2d, which isn't the case in general.

An intermediate PR would have to ravel all the args and do book-keeping on their shapes, which is exactly what pack/unpack are for. So I don't see anything to split out, except the L_ops, which I already did.

@jessegrabowski
Copy link
Member Author

I got rid of the eager graph_rewrite, which fixed the scan bug I was hitting.

As a result, I had to implement some logic to filter non-differentiable args, which was a bug you had previously hit. The disconnected_inputs='raise' case works now, so I adjusted the test.

@ricardoV94
Copy link
Member

Sounds like nice progress. If you're just splitting the lop, don't bother, those are simple enough

@jessegrabowski
Copy link
Member Author

Sounds like nice progress. If you're just splitting the lop, don't bother, those are simple enough

Too late. You're already tagged too.

@ricardoV94
Copy link
Member

Sounds like nice progress. If you're just splitting the lop, don't bother, those are simple enough

Too late. You're already tagged too.

You better not have used reshape

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 6, 2026

Shape of unpacked inputs (and indices) and most cases of integers will show up as disconnected correctly. I'm saying just not every single integer input is of that nature.

Were you actually seeing integers used in unpack showing up as connected? That would mean connection_pattern helper isn't working as expected. You shouldn't need an extra manual filter.

Edit: That was indeed the problem, see next comment

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 6, 2026

Okay, so the confusion for me comes from Split Op not considering the gradient disconnected wrt to the split_size argument, which is puzzling (fixed in #1828). I'm opening an issue to discuss this #1827

The bigger issue is that io_connection_pattern doesn't give us all we want. It doesn't inform about the null gradients that are still "connected", whatever those mean (if they mean anything). See my revised proposal in the next comment.

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 6, 2026

@jessegrabowski This is a summary of my current understanding of what Minimize.L_op should do:

  1. Implement connection_pattern, because pytensor.grad will be annoyed if you return disconnected_grad without telling it in advance, but don't rely on it in L_op because it's not sufficient.
  2. Call gradient in the inner graph, asking it to return null types, and keep track of which are disconnected_grad/undefined_grad/not_implemented_grad. These can be removed from the internal jacobian, and the L_op, should just remember what it is to preserve the meaning of disconnected / undefined / not_implemented at the end. I think this is pretty much what Scan is doing in the snippet below.
  3. If by any-chance there's a non-numerical input for which you get a non-null gradient back, still exclude it from the internal jacobian and return grad_not_implemented for this input. Supposedly it is differentiable but you just don't know how to handle it alongside regular TensorVariables. Think about SparseVariables (if they aren't supported yet), or a newer type you haven't seen that is also differentiable.
  4. Do not worry whether the remaining numerical inputs are integers or floats, the gradient should make sense as per current PyTensor API. We can debate that, I don't care much about it, but I don't feel it has to be done in this specific PR and for now we should remain consistent.
  5. If you still see variables that you feel shouldn't be differentiable wrt, this is more a discussion akin to Confusion between grad_undefined / grad_disconnected #1827 or Fix issues with split and split_dims #1828, and not Minimize specifically.

This approach would rule out the shape variables used in Split regardless of #1828, as they would still be linked to grad_undefined.

Does that make sense?

pytensor/pytensor/scan/op.py

Lines 2547 to 2555 in 79a4bc1

grads = grad(
cost=None,
known_grads=known_grads,
wrt=wrt,
consider_constant=wrt,
disconnected_inputs="ignore",
return_disconnected="None",
null_gradients="return",
)

pytensor/pytensor/scan/op.py

Lines 3086 to 3110 in 79a4bc1

if t == "connected":
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if info.as_while:
n_zeros = inputs[0] - n_steps
shp = (n_zeros,)
if x.ndim > 1:
shp = shp + tuple(x.shape[i] for i in range(1, x.ndim))
z = pt.zeros(shp, dtype=x.dtype)
x = pt.concatenate([x[::-1], z], axis=0)
gradients.append(x)
else:
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
elif t == "through_untraced":
gradients.append(
grad_undefined(
self, p + 1, inputs[p + 1], "Depends on a untraced variable"
)
)
else:
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())

@jessegrabowski
Copy link
Member Author

Tests are passing with #1806

I ended up having to use grad in the connection_pattern, because io_connection_pattern wasn't cutting it. Maybe you had something else in mind, but I hope not.

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 9, 2026

I ended up having to use grad in the connection_pattern, because io_connection_pattern wasn't cutting it. Maybe you had something else in mind, but I hope not.

You shouldn't have to? As long as you are not returing disconnected that were not in connection_pattern I think that's fine. You may be replacing nulltypes with disconnected accidentally?

And at most is should just result in a warning from PyTensor

)

with pytest.raises(NullTypeGradError):
with pytest.raises(DisconnectedInputError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this why io_connection_pattern didn't suffice? It seems you converted NullType grads to Disconnected, beyond what would be implied by the connection pattern. It doesn't bother me too much, but could mean you can simplify that Op method to not call grad

Copy link
Member Author

@jessegrabowski jessegrabowski Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, see my comment below. My only other thought is that there's a bug upstream where the indexing grads are DisconnectedType but they should be NullType

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, just some minor questions / outdated comments

[atleast_2d(df_dx), df_dtheta], replace=replace
)
if arg_grad is None:
final_grads.append(DisconnectedType()())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that if the original grad was NullType, you should return that, unless it would also be disconnected even if it wasn't Null

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a biggie, but maybe why you needed the call to grad in connection_pattern method

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My code here is incorrect, but its not the reason why i needed to call grad in the connection_pattern. For example, in the test_optimize_multiple_minimands test case, the (outer) args have the following types:

[(ExpandDims{axis=0}.0, TensorType(int8, shape=(1,))), 
(Subtensor{start:stop}.0, TensorType(int64, shape=(1,))), 
(Prod{axes=None}.0, TensorType(int64, shape=())), 
(Prod{axes=None}.0, TensorType(int64, shape=())),
 (Prod{axes=None}.0, TensorType(int64, shape=())), 
(input 7, TensorType(float64, shape=(100, 5))), 
(Subtensor{start:stop}.0, TensorType(int64, shape=(0,))), 
(input 6, TensorType(float64, shape=(100,))), 
(Subtensor{start:stop}.0, TensorType(int64, shape=(0,))), 
(input 5, TensorType(float64, shape=(100,))), 
(Subtensor{start:stop}.0, TensorType(int64, shape=(0,))), 
(input 4, TensorType(float64, shape=(100,))), 
(input 8, TensorType(float64, shape=(100,)))]

Here is the connection pattern generated by io_connection_pattern:

[[True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False], [True, False]]

And here are the gradients of the inner function:

[Squeeze{axis=0}.0, <DisconnectedType>, <DisconnectedType>, <DisconnectedType>, <DisconnectedType>, Reshape{2}.0, <DisconnectedType>, Squeeze{axis=0}.0, <DisconnectedType>, Squeeze{axis=0}.0, <DisconnectedType>, Squeeze{axis=0}.0, Squeeze{axis=0}.0]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me take a look in the debugger

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working enhancement New feature or request feature request SciPy compatibility

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gradient of MinimizeOp fails with certain parameter shapes

2 participants