@@ -213,7 +213,7 @@ def Rop(

     # Check that each element of wrt corresponds to an element
     # of eval_points with the same dimensionality.
-    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
         try:
             if wrt_elem.type.ndim != eval_point.type.ndim:
                 raise ValueError(
@@ -262,7 +262,7 @@ def _traverse(node):
                 seen_nodes[inp.owner][inp.owner.outputs.index(inp)]
             )
         same_type_eval_points = []
-        for x, y in zip(inputs, local_eval_points):
+        for x, y in zip(inputs, local_eval_points, strict=True):
             if y is not None:
                 if not isinstance(x, Variable):
                     x = pytensor.tensor.as_tensor_variable(x)
@@ -399,7 +399,7 @@ def Lop(
     _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]

     assert len(_f) == len(grads)
-    known = dict(zip(_f, grads))
+    known = dict(zip(_f, grads, strict=True))

     ret = grad(
         cost=None,
@@ -778,7 +778,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
         for i in range(len(grads)):
             grads[i] += cost_grads[i]

-    pgrads = dict(zip(params, grads))
+    pgrads = dict(zip(params, grads, strict=True))
     # separate wrt from end grads:
     wrt_grads = [pgrads[k] for k in wrt]
     end_grads = [pgrads[k] for k in end]
@@ -1044,7 +1044,7 @@ def access_term_cache(node):
                 any(
                     input_to_output and output_to_cost
                     for input_to_output, output_to_cost in zip(
-                        input_to_outputs, outputs_connected
+                        input_to_outputs, outputs_connected, strict=True
                     )
                 )
             )
@@ -1069,7 +1069,7 @@ def access_term_cache(node):
                 not any(
                     in_to_out and out_to_cost and not out_nan
                     for in_to_out, out_to_cost, out_nan in zip(
-                        in_to_outs, outputs_connected, ograd_is_nan
+                        in_to_outs, outputs_connected, ograd_is_nan, strict=True
                     )
                 )
             )
@@ -1129,7 +1129,7 @@ def try_to_copy_if_needed(var):
             # DO NOT force integer variables to have integer dtype.
             # This is a violation of the op contract.
             new_output_grads = []
-            for o, og in zip(node.outputs, output_grads):
+            for o, og in zip(node.outputs, output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 og_dt = getattr(og.type, "dtype", None)
                 if (
@@ -1143,7 +1143,7 @@ def try_to_copy_if_needed(var):

             # Make sure that, if new_output_grads[i] has a floating point
             # dtype, it is the same dtype as outputs[i]
-            for o, ng in zip(node.outputs, new_output_grads):
+            for o, ng in zip(node.outputs, new_output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 ng_dt = getattr(ng.type, "dtype", None)
                 if (
@@ -1165,7 +1165,9 @@ def try_to_copy_if_needed(var):
             # by the user, not computed by Op.grad, and some gradients are
             # only computed and returned, but never passed as another
             # node's output grads.
-            for idx, packed in enumerate(zip(node.outputs, new_output_grads)):
+            for idx, packed in enumerate(
+                zip(node.outputs, new_output_grads, strict=True)
+            ):
                 orig_output, new_output_grad = packed
                 if not hasattr(orig_output, "shape"):
                     continue
@@ -1231,7 +1233,7 @@ def try_to_copy_if_needed(var):
                     not in [
                         in_to_out and out_to_cost and not out_int
                         for in_to_out, out_to_cost, out_int in zip(
-                            in_to_outs, outputs_connected, output_is_int
+                            in_to_outs, outputs_connected, output_is_int, strict=True
                         )
                     ]
                 )
@@ -1312,7 +1314,7 @@ def try_to_copy_if_needed(var):
             # Check that op.connection_pattern matches the connectivity
             # logic driving the op.grad method
             for i, (ipt, ig, connected) in enumerate(
-                zip(inputs, input_grads, inputs_connected)
+                zip(inputs, input_grads, inputs_connected, strict=True)
             ):
                 actually_connected = not isinstance(ig.type, DisconnectedType)

@@ -1599,7 +1601,7 @@ def abs_rel_errors(self, g_pt):
         if len(g_pt) != len(self.gf):
             raise ValueError("argument has wrong number of elements", len(g_pt))
         errs = []
-        for i, (a, b) in enumerate(zip(g_pt, self.gf)):
+        for i, (a, b) in enumerate(zip(g_pt, self.gf, strict=True)):
             if a.shape != b.shape:
                 raise ValueError(
                     f"argument element {i} has wrong shapes {a.shape}, {b.shape}"