@@ -447,8 +447,8 @@ def compute_implicit_gradients(
447447 )
448448 grad_wrt_args = unpack (
449449 grad_wrt_args_packed ,
450- packed_shapes = packed_shapes_from_outer ,
451- axes = 0 if not all (inp .ndim == 0 for inp in (x_star , * args )) else None ,
450+ packed_shapes_from_outer ,
451+ keep_axes = None if all (inp .ndim == 0 for inp in (x_star , * args )) else 0 ,
452452 )
453453 else :
454454 grad_wrt_args = [grad_wrt_args_packed ]
@@ -702,8 +702,8 @@ def pack_inputs_of_objective(
702702 elif len (x ) == 1 :
703703 packed_input = x [0 ]
704704 else :
705- packed_input , packed_shapes = pack (* x , axes = None )
706- unpacked_output = unpack (packed_input , axes = None , packed_shapes = packed_shapes )
705+ packed_input , packed_shapes = pack (* x )
706+ unpacked_output = unpack (packed_input , packed_shapes )
707707
708708 objective = graph_replace (
709709 objective ,
@@ -781,7 +781,7 @@ def minimize(
781781 solution , success = minimize_op (packed_input , * args )
782782
783783 if packed_shapes is not None :
784- solution = unpack (solution , axes = None , packed_shapes = packed_shapes )
784+ solution = unpack (solution , packed_shapes )
785785
786786 return solution , success
787787
@@ -1109,7 +1109,7 @@ def root(
11091109
11101110 solution , success = root_op (packed_variables , * args )
11111111 if packed_shapes is not None :
1112- solution = unpack (solution , axes = None , packed_shapes = packed_shapes )
1112+ solution = unpack (solution , packed_shapes )
11131113
11141114 return solution , success
11151115
0 commit comments