@@ -213,7 +213,7 @@ def Rop(

     # Check that each element of wrt corresponds to an element
     # of eval_points with the same dimensionality.
-    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
         try:
             if wrt_elem.type.ndim != eval_point.type.ndim:
                 raise ValueError(
@@ -262,7 +262,7 @@ def _traverse(node):
                     seen_nodes[inp.owner][inp.owner.outputs.index(inp)]
                 )
         same_type_eval_points = []
-        for x, y in zip(inputs, local_eval_points):
+        for x, y in zip(inputs, local_eval_points, strict=True):
             if y is not None:
                 if not isinstance(x, Variable):
                     x = pytensor.tensor.as_tensor_variable(x)
@@ -399,7 +399,7 @@ def Lop(
     _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]

     assert len(_f) == len(grads)
-    known = dict(zip(_f, grads))
+    known = dict(zip(_f, grads, strict=True))

     ret = grad(
         cost=None,
@@ -778,7 +778,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
         for i in range(len(grads)):
             grads[i] += cost_grads[i]

-    pgrads = dict(zip(params, grads))
+    pgrads = dict(zip(params, grads, strict=True))
     # separate wrt from end grads:
     wrt_grads = [pgrads[k] for k in wrt]
     end_grads = [pgrads[k] for k in end]
@@ -1045,7 +1045,7 @@ def access_term_cache(node):
                     in [
                         input_to_output and output_to_cost
                         for input_to_output, output_to_cost in zip(
-                            input_to_outputs, outputs_connected
+                            input_to_outputs, outputs_connected, strict=True
                         )
                     ]
                 )
@@ -1071,7 +1071,7 @@ def access_term_cache(node):
                     not in [
                         in_to_out and out_to_cost and not out_nan
                         for in_to_out, out_to_cost, out_nan in zip(
-                            in_to_outs, outputs_connected, ograd_is_nan
+                            in_to_outs, outputs_connected, ograd_is_nan, strict=True
                         )
                     ]
                 )
@@ -1131,7 +1131,7 @@ def try_to_copy_if_needed(var):
             # DO NOT force integer variables to have integer dtype.
             # This is a violation of the op contract.
             new_output_grads = []
-            for o, og in zip(node.outputs, output_grads):
+            for o, og in zip(node.outputs, output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 og_dt = getattr(og.type, "dtype", None)
                 if (
@@ -1145,7 +1145,7 @@ def try_to_copy_if_needed(var):

             # Make sure that, if new_output_grads[i] has a floating point
             # dtype, it is the same dtype as outputs[i]
-            for o, ng in zip(node.outputs, new_output_grads):
+            for o, ng in zip(node.outputs, new_output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 ng_dt = getattr(ng.type, "dtype", None)
                 if (
@@ -1167,7 +1167,9 @@ def try_to_copy_if_needed(var):
             # by the user, not computed by Op.grad, and some gradients are
             # only computed and returned, but never passed as another
             # node's output grads.
-            for idx, packed in enumerate(zip(node.outputs, new_output_grads)):
+            for idx, packed in enumerate(
+                zip(node.outputs, new_output_grads, strict=True)
+            ):
                 orig_output, new_output_grad = packed
                 if not hasattr(orig_output, "shape"):
                     continue
@@ -1233,7 +1235,7 @@ def try_to_copy_if_needed(var):
                     not in [
                         in_to_out and out_to_cost and not out_int
                         for in_to_out, out_to_cost, out_int in zip(
-                            in_to_outs, outputs_connected, output_is_int
+                            in_to_outs, outputs_connected, output_is_int, strict=True
                         )
                     ]
                 )
@@ -1314,7 +1316,7 @@ def try_to_copy_if_needed(var):
             # Check that op.connection_pattern matches the connectivity
             # logic driving the op.grad method
             for i, (ipt, ig, connected) in enumerate(
-                zip(inputs, input_grads, inputs_connected)
+                zip(inputs, input_grads, inputs_connected, strict=True)
             ):
                 actually_connected = not isinstance(ig.type, DisconnectedType)

@@ -1601,7 +1603,7 @@ def abs_rel_errors(self, g_pt):
         if len(g_pt) != len(self.gf):
             raise ValueError("argument has wrong number of elements", len(g_pt))
         errs = []
-        for i, (a, b) in enumerate(zip(g_pt, self.gf)):
+        for i, (a, b) in enumerate(zip(g_pt, self.gf, strict=True)):
             if a.shape != b.shape:
                 raise ValueError(
                     f"argument element {i} has wrong shapes {a.shape}, {b.shape}"
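Every hunk above applies the same pattern: passing strict=True to zip(), available since Python 3.10 (PEP 618), makes a length mismatch between the zipped iterables raise a ValueError instead of silently truncating to the shortest one. A minimal sketch of the behavior difference, using made-up lists rather than code from this diff:

# Hypothetical stand-ins for sequences the gradient code expects to
# stay in lockstep, e.g. a node's outputs and their gradients.
outputs = ["out0", "out1", "out2"]
output_grads = ["g0", "g1"]  # one gradient is accidentally missing

# Plain zip() hides the bug by truncating to the shorter iterable:
assert list(zip(outputs, output_grads)) == [("out0", "g0"), ("out1", "g1")]

# zip(..., strict=True) turns the mismatch into an immediate error:
try:
    dict(zip(outputs, output_grads, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1

For the dict(zip(...)) call sites (in Lop and subgraph_grad), this guards the key-to-value pairing that the gradient bookkeeping relies on; when the lengths already match, the flag changes nothing.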