@@ -139,10 +139,14 @@ def _get_parameter_grads_from_vector(
139139 Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
140140 returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
141141 """
142+ grad_wrt_args_vector = cast (TensorVariable , grad_wrt_args_vector )
143+ x_star = cast (TensorVariable , x_star )
144+
142145 cursor = 0
143146 grad_wrt_args = []
144147
145148 for arg in args :
149+ arg = cast (TensorVariable , arg )
146150 arg_shape = arg .shape
147151 arg_size = arg_shape .prod ()
148152 arg_grad = grad_wrt_args_vector [:, cursor : cursor + arg_size ].reshape (
@@ -233,16 +237,17 @@ def scalar_implict_optimization_grads(
233237 output_grad : Variable ,
234238 fgraph : FunctionGraph ,
235239) -> list [Variable ]:
236- df_dx , * df_dthetas = grad (
237- inner_fx , [inner_x , * inner_args ], disconnected_inputs = "ignore"
240+ df_dx , * df_dthetas = cast (
241+ list [Variable ],
242+ grad (inner_fx , [inner_x , * inner_args ], disconnected_inputs = "ignore" ),
238243 )
239244
240245 replace = dict (zip (fgraph .inputs , (x_star , * args ), strict = True ))
241246 df_dx_star , * df_dthetas_stars = graph_replace ([df_dx , * df_dthetas ], replace = replace )
242247
243248 grad_wrt_args = [
244249 (- df_dtheta_star / df_dx_star ) * output_grad
245- for df_dtheta_star in df_dthetas_stars
250+ for df_dtheta_star in cast ( list [ TensorVariable ], df_dthetas_stars )
246251 ]
247252
248253 return grad_wrt_args
@@ -297,15 +302,21 @@ def implict_optimization_grads(
297302 fgraph : FunctionGraph
298303 The function graph that contains the inputs and outputs of the optimization problem.
299304 """
305+ df_dx = cast (TensorVariable , df_dx )
306+
300307 df_dtheta = concatenate (
301- [atleast_2d (jac_col , left = False ) for jac_col in df_dtheta_columns ],
308+ [
309+ atleast_2d (jac_col , left = False )
310+ for jac_col in cast (list [TensorVariable ], df_dtheta_columns )
311+ ],
302312 axis = - 1 ,
303313 )
304314
305315 replace = dict (zip (fgraph .inputs , (x_star , * args ), strict = True ))
306316
307- df_dx_star , df_dtheta_star = graph_replace (
308- [atleast_2d (df_dx ), df_dtheta ], replace = replace
317+ df_dx_star , df_dtheta_star = cast (
318+ list [TensorVariable ],
319+ graph_replace ([atleast_2d (df_dx ), df_dtheta ], replace = replace ),
309320 )
310321
311322 grad_wrt_args_vector = solve (- df_dx_star , df_dtheta_star )
@@ -546,7 +557,9 @@ def __init__(
546557 self .fgraph = FunctionGraph ([variables , * args ], [equation ])
547558
548559 if jac :
549- f_prime = grad (self .fgraph .outputs [0 ], self .fgraph .inputs [0 ])
560+ f_prime = cast (
561+ Variable , grad (self .fgraph .outputs [0 ], self .fgraph .inputs [0 ])
562+ )
550563 self .fgraph .add_output (f_prime )
551564
552565 if hess :
@@ -555,7 +568,9 @@ def __init__(
555568 "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
556569 " using first derivatives."
557570 )
558- f_double_prime = grad (self .fgraph .outputs [- 1 ], self .fgraph .inputs [0 ])
571+ f_double_prime = cast (
572+ Variable , grad (self .fgraph .outputs [- 1 ], self .fgraph .inputs [0 ])
573+ )
559574 self .fgraph .add_output (f_double_prime )
560575
561576 self .method = method
0 commit comments