@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
171
171
)
172
172
size = min (v1 .type .ndim , v2 .type .ndim )
173
173
for n , (b1 , b2 ) in enumerate (
174
- zip (v1 .type .broadcastable [- size :], v2 .type .broadcastable [- size :])
174
+ zip (v1 .type .broadcastable [- size :], v2 .type .broadcastable [- size :], strict = False )
175
175
):
176
176
if b1 != b2 :
177
177
a1 = n + size - v1 .type .ndim + 1
@@ -578,6 +578,7 @@ def get_oinp_iinp_iout_oout_mappings(self):
578
578
inner_input_indices ,
579
579
inner_output_indices ,
580
580
outer_output_indices ,
581
+ strict = True ,
581
582
):
582
583
if oout != - 1 :
583
584
mappings ["outer_inp_from_outer_out" ][oout ] = oinp
@@ -959,7 +960,7 @@ def make_node(self, *inputs):
959
960
# them have the same dtype
960
961
argoffset = 0
961
962
for inner_seq , outer_seq in zip (
962
- self .inner_seqs (self .inner_inputs ), self .outer_seqs (inputs )
963
+ self .inner_seqs (self .inner_inputs ), self .outer_seqs (inputs ), strict = True
963
964
):
964
965
check_broadcast (outer_seq , inner_seq )
965
966
new_inputs .append (copy_var_format (outer_seq , as_var = inner_seq ))
@@ -978,6 +979,7 @@ def make_node(self, *inputs):
978
979
self .info .mit_mot_in_slices ,
979
980
self .info .mit_mot_out_slices [: self .info .n_mit_mot ],
980
981
self .outer_mitmot (inputs ),
982
+ strict = True ,
981
983
)
982
984
):
983
985
outer_mitmot = copy_var_format (_outer_mitmot , as_var = inner_mitmot [ipos ])
@@ -1032,6 +1034,7 @@ def make_node(self, *inputs):
1032
1034
self .info .mit_sot_in_slices ,
1033
1035
self .outer_mitsot (inputs ),
1034
1036
self .inner_mitsot_outs (self .inner_outputs ),
1037
+ strict = True ,
1035
1038
)
1036
1039
):
1037
1040
outer_mitsot = copy_var_format (_outer_mitsot , as_var = inner_mitsots [ipos ])
@@ -1084,6 +1087,7 @@ def make_node(self, *inputs):
1084
1087
self .inner_sitsot (self .inner_inputs ),
1085
1088
self .outer_sitsot (inputs ),
1086
1089
self .inner_sitsot_outs (self .inner_outputs ),
1090
+ strict = True ,
1087
1091
)
1088
1092
):
1089
1093
outer_sitsot = copy_var_format (_outer_sitsot , as_var = inner_sitsot )
@@ -1131,6 +1135,7 @@ def make_node(self, *inputs):
1131
1135
self .inner_shared (self .inner_inputs ),
1132
1136
self .inner_shared_outs (self .inner_outputs ),
1133
1137
self .outer_shared (inputs ),
1138
+ strict = True ,
1134
1139
)
1135
1140
):
1136
1141
outer_shared = copy_var_format (_outer_shared , as_var = inner_shared )
@@ -1189,7 +1194,9 @@ def make_node(self, *inputs):
1189
1194
# type of tensor as the output, it is always a scalar int.
1190
1195
new_inputs += [as_tensor_variable (ons ) for ons in self .outer_nitsot (inputs )]
1191
1196
for inner_nonseq , _outer_nonseq in zip (
1192
- self .inner_non_seqs (self .inner_inputs ), self .outer_non_seqs (inputs )
1197
+ self .inner_non_seqs (self .inner_inputs ),
1198
+ self .outer_non_seqs (inputs ),
1199
+ strict = True ,
1193
1200
):
1194
1201
outer_nonseq = copy_var_format (_outer_nonseq , as_var = inner_nonseq )
1195
1202
new_inputs .append (outer_nonseq )
@@ -1272,7 +1279,9 @@ def __eq__(self, other):
1272
1279
if len (self .inner_outputs ) != len (other .inner_outputs ):
1273
1280
return False
1274
1281
1275
- for self_in , other_in in zip (self .inner_inputs , other .inner_inputs ):
1282
+ for self_in , other_in in zip (
1283
+ self .inner_inputs , other .inner_inputs , strict = True
1284
+ ):
1276
1285
if self_in .type != other_in .type :
1277
1286
return False
1278
1287
@@ -1407,7 +1416,7 @@ def prepare_fgraph(self, fgraph):
1407
1416
fgraph .attach_feature (
1408
1417
Supervisor (
1409
1418
inp
1410
- for spec , inp in zip (wrapped_inputs , fgraph .inputs )
1419
+ for spec , inp in zip (wrapped_inputs , fgraph .inputs , strict = True )
1411
1420
if not (
1412
1421
getattr (spec , "mutable" , None )
1413
1422
or (hasattr (fgraph , "destroyers" ) and fgraph .has_destroyers ([inp ]))
@@ -2087,7 +2096,9 @@ def perform(self, node, inputs, output_storage):
2087
2096
jout = j + offset_out
2088
2097
output_storage [j ][0 ] = inner_output_storage [jout ].storage [0 ]
2089
2098
2090
- pos = [(idx + 1 ) % store for idx , store in zip (pos , store_steps )]
2099
+ pos = [
2100
+ (idx + 1 ) % store for idx , store in zip (pos , store_steps , strict = True )
2101
+ ]
2091
2102
i = i + 1
2092
2103
2093
2104
# 6. Check if you need to re-order output buffers
@@ -2172,7 +2183,7 @@ def perform(self, node, inputs, output_storage):
2172
2183
2173
2184
def infer_shape (self , fgraph , node , input_shapes ):
2174
2185
# input_shapes correspond to the shapes of node.inputs
2175
- for inp , inp_shp in zip (node .inputs , input_shapes ):
2186
+ for inp , inp_shp in zip (node .inputs , input_shapes , strict = True ):
2176
2187
assert inp_shp is None or len (inp_shp ) == inp .type .ndim
2177
2188
2178
2189
# Here we build 2 variables;
@@ -2241,7 +2252,9 @@ def infer_shape(self, fgraph, node, input_shapes):
2241
2252
# Non-sequences have a direct equivalent from self.inner_inputs in
2242
2253
# node.inputs
2243
2254
inner_non_sequences = self .inner_inputs [len (seqs_shape ) + len (outs_shape ) :]
2244
- for in_ns , out_ns in zip (inner_non_sequences , node .inputs [offset :]):
2255
+ for in_ns , out_ns in zip (
2256
+ inner_non_sequences , node .inputs [offset :], strict = True
2257
+ ):
2245
2258
out_equivalent [in_ns ] = out_ns
2246
2259
2247
2260
if info .as_while :
@@ -2276,7 +2289,7 @@ def infer_shape(self, fgraph, node, input_shapes):
2276
2289
r = node .outputs [n_outs + x ]
2277
2290
assert r .ndim == 1 + len (out_shape_x )
2278
2291
shp = [node .inputs [offset + info .n_shared_outs + x ]]
2279
- for i , shp_i in zip (range (1 , r .ndim ), out_shape_x ):
2292
+ for i , shp_i in zip (range (1 , r .ndim ), out_shape_x , strict = True ):
2280
2293
# Validate shp_i. v_shape_i is either None (if invalid),
2281
2294
# or a (variable, Boolean) tuple. The Boolean indicates
2282
2295
# whether variable is shp_i (if True), or an valid
@@ -2298,7 +2311,7 @@ def infer_shape(self, fgraph, node, input_shapes):
2298
2311
if info .as_while :
2299
2312
scan_outs_init = scan_outs
2300
2313
scan_outs = []
2301
- for o , x in zip (node .outputs , scan_outs_init ):
2314
+ for o , x in zip (node .outputs , scan_outs_init , strict = True ):
2302
2315
if x is None :
2303
2316
scan_outs .append (None )
2304
2317
else :
@@ -2574,7 +2587,9 @@ def compute_all_gradients(known_grads):
2574
2587
dC_dinps_t [dx ] = pt .zeros_like (diff_inputs [dx ])
2575
2588
else :
2576
2589
disconnected_dC_dinps_t [dx ] = False
2577
- for Xt , Xt_placeholder in zip (diff_outputs [info .n_mit_mot_outs :], Xts ):
2590
+ for Xt , Xt_placeholder in zip (
2591
+ diff_outputs [info .n_mit_mot_outs :], Xts , strict = True
2592
+ ):
2578
2593
tmp = forced_replace (dC_dinps_t [dx ], Xt , Xt_placeholder )
2579
2594
dC_dinps_t [dx ] = tmp
2580
2595
@@ -2654,7 +2669,9 @@ def compute_all_gradients(known_grads):
2654
2669
n = n_steps .tag .test_value
2655
2670
else :
2656
2671
n = inputs [0 ].tag .test_value
2657
- for taps , x in zip (info .mit_sot_in_slices , self .outer_mitsot_outs (outs )):
2672
+ for taps , x in zip (
2673
+ info .mit_sot_in_slices , self .outer_mitsot_outs (outs ), strict = True
2674
+ ):
2658
2675
mintap = np .min (taps )
2659
2676
if hasattr (x [::- 1 ][:mintap ], "test_value" ):
2660
2677
assert x [::- 1 ][:mintap ].tag .test_value .shape [0 ] == n
@@ -2669,7 +2686,9 @@ def compute_all_gradients(known_grads):
2669
2686
assert x [::- 1 ].tag .test_value .shape [0 ] == n
2670
2687
outer_inp_seqs += [
2671
2688
x [::- 1 ][: np .min (taps )]
2672
- for taps , x in zip (info .mit_sot_in_slices , self .outer_mitsot_outs (outs ))
2689
+ for taps , x in zip (
2690
+ info .mit_sot_in_slices , self .outer_mitsot_outs (outs ), strict = True
2691
+ )
2673
2692
]
2674
2693
outer_inp_seqs += [x [::- 1 ][:- 1 ] for x in self .outer_sitsot_outs (outs )]
2675
2694
outer_inp_seqs += [x [::- 1 ] for x in self .outer_nitsot_outs (outs )]
@@ -3000,6 +3019,7 @@ def compute_all_gradients(known_grads):
3000
3019
zip (
3001
3020
outputs [offset : offset + info .n_seqs ],
3002
3021
type_outs [offset : offset + info .n_seqs ],
3022
+ strict = True ,
3003
3023
)
3004
3024
):
3005
3025
if t == "connected" :
@@ -3029,7 +3049,7 @@ def compute_all_gradients(known_grads):
3029
3049
gradients .append (NullType (t )())
3030
3050
3031
3051
end = info .n_mit_mot + info .n_mit_sot + info .n_sit_sot
3032
- for p , (x , t ) in enumerate (zip (outputs [:end ], type_outs [:end ])):
3052
+ for p , (x , t ) in enumerate (zip (outputs [:end ], type_outs [:end ], strict = True )):
3033
3053
if t == "connected" :
3034
3054
# If the forward scan is in as_while mode, we need to pad
3035
3055
# the gradients, so that they match the size of the input
@@ -3064,7 +3084,7 @@ def compute_all_gradients(known_grads):
3064
3084
for idx in range (info .n_shared_outs ):
3065
3085
disconnected = True
3066
3086
connected_flags = self .connection_pattern (node )[idx + start ]
3067
- for dC_dout , connected in zip (dC_douts , connected_flags ):
3087
+ for dC_dout , connected in zip (dC_douts , connected_flags , strict = True ):
3068
3088
if not isinstance (dC_dout .type , DisconnectedType ) and connected :
3069
3089
disconnected = False
3070
3090
if disconnected :
@@ -3081,7 +3101,9 @@ def compute_all_gradients(known_grads):
3081
3101
begin = end
3082
3102
3083
3103
end = begin + n_sitsot_outs
3084
- for p , (x , t ) in enumerate (zip (outputs [begin :end ], type_outs [begin :end ])):
3104
+ for p , (x , t ) in enumerate (
3105
+ zip (outputs [begin :end ], type_outs [begin :end ], strict = True )
3106
+ ):
3085
3107
if t == "connected" :
3086
3108
gradients .append (x [- 1 ])
3087
3109
elif t == "disconnected" :
@@ -3158,7 +3180,7 @@ def R_op(self, inputs, eval_points):
3158
3180
e = 1 + info .n_seqs
3159
3181
ie = info .n_seqs
3160
3182
clean_eval_points = []
3161
- for inp , evp in zip (inputs [b :e ], eval_points [b :e ]):
3183
+ for inp , evp in zip (inputs [b :e ], eval_points [b :e ], strict = True ):
3162
3184
if evp is not None :
3163
3185
clean_eval_points .append (evp )
3164
3186
else :
@@ -3173,7 +3195,7 @@ def R_op(self, inputs, eval_points):
3173
3195
ib = ie
3174
3196
ie = ie + int (sum (len (x ) for x in info .mit_mot_in_slices ))
3175
3197
clean_eval_points = []
3176
- for inp , evp in zip (inputs [b :e ], eval_points [b :e ]):
3198
+ for inp , evp in zip (inputs [b :e ], eval_points [b :e ], strict = True ):
3177
3199
if evp is not None :
3178
3200
clean_eval_points .append (evp )
3179
3201
else :
@@ -3188,7 +3210,7 @@ def R_op(self, inputs, eval_points):
3188
3210
ib = ie
3189
3211
ie = ie + int (sum (len (x ) for x in info .mit_sot_in_slices ))
3190
3212
clean_eval_points = []
3191
- for inp , evp in zip (inputs [b :e ], eval_points [b :e ]):
3213
+ for inp , evp in zip (inputs [b :e ], eval_points [b :e ], strict = True ):
3192
3214
if evp is not None :
3193
3215
clean_eval_points .append (evp )
3194
3216
else :
@@ -3203,7 +3225,7 @@ def R_op(self, inputs, eval_points):
3203
3225
ib = ie
3204
3226
ie = ie + info .n_sit_sot
3205
3227
clean_eval_points = []
3206
- for inp , evp in zip (inputs [b :e ], eval_points [b :e ]):
3228
+ for inp , evp in zip (inputs [b :e ], eval_points [b :e ], strict = True ):
3207
3229
if evp is not None :
3208
3230
clean_eval_points .append (evp )
3209
3231
else :
@@ -3227,7 +3249,7 @@ def R_op(self, inputs, eval_points):
3227
3249
3228
3250
# All other arguments
3229
3251
clean_eval_points = []
3230
- for inp , evp in zip (inputs [e :], eval_points [e :]):
3252
+ for inp , evp in zip (inputs [e :], eval_points [e :], strict = True ):
3231
3253
if evp is not None :
3232
3254
clean_eval_points .append (evp )
3233
3255
else :
0 commit comments