Skip to content

Commit 17d709e

Browse files
committed
Simplify pack/unpack calls
1 parent 5f96701 commit 17d709e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pytensor/tensor/optimize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,8 @@ def compute_implicit_gradients(
447447
)
448448
grad_wrt_args = unpack(
449449
grad_wrt_args_packed,
450-
packed_shapes=packed_shapes_from_outer,
451-
axes=0 if not all(inp.ndim == 0 for inp in (x_star, *args)) else None,
450+
packed_shapes_from_outer,
451+
keep_axes=None if all(inp.ndim == 0 for inp in (x_star, *args)) else 0,
452452
)
453453
else:
454454
grad_wrt_args = [grad_wrt_args_packed]
@@ -702,8 +702,8 @@ def pack_inputs_of_objective(
702702
elif len(x) == 1:
703703
packed_input = x[0]
704704
else:
705-
packed_input, packed_shapes = pack(*x, axes=None)
706-
unpacked_output = unpack(packed_input, axes=None, packed_shapes=packed_shapes)
705+
packed_input, packed_shapes = pack(*x)
706+
unpacked_output = unpack(packed_input, packed_shapes)
707707

708708
objective = graph_replace(
709709
objective,
@@ -781,7 +781,7 @@ def minimize(
781781
solution, success = minimize_op(packed_input, *args)
782782

783783
if packed_shapes is not None:
784-
solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
784+
solution = unpack(solution, packed_shapes)
785785

786786
return solution, success
787787

@@ -1109,7 +1109,7 @@ def root(
11091109

11101110
solution, success = root_op(packed_variables, *args)
11111111
if packed_shapes is not None:
1112-
solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
1112+
solution = unpack(solution, packed_shapes)
11131113

11141114
return solution, success
11151115

0 commit comments

Comments
 (0)