@@ -170,7 +170,7 @@ def check_broadcast(v1, v2):
     )
     size = min(v1.type.ndim, v2.type.ndim)
     for n, (b1, b2) in enumerate(
-        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
+        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
     ):
         if b1 != b2:
             a1 = n + size - v1.type.ndim + 1
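Note: this first hunk is the only one that opts for `strict=False`; every other call site in the patch uses `strict=True`. On Python 3.10+, `zip(..., strict=True)` raises `ValueError` when its iterables turn out to have different lengths, while `strict=False` keeps the historical silent-truncation behavior. A minimal sketch of the difference (the `taps`/`outs` names are illustrative, not from the patch):

```python
taps = [0, 1, 2]
outs = ["a", "b"]

# strict=False (also the default): stops quietly at the shorter iterable.
assert list(zip(taps, outs, strict=False)) == [(0, "a"), (1, "b")]

# strict=True: a length mismatch is treated as a bug and raises.
try:
    list(zip(taps, outs, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1
```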
@@ -577,6 +577,7 @@ def get_oinp_iinp_iout_oout_mappings(self):
             inner_input_indices,
             inner_output_indices,
             outer_output_indices,
+            strict=True,
         ):
             if oout != -1:
                 mappings["outer_inp_from_outer_out"][oout] = oinp
@@ -958,7 +959,7 @@ def make_node(self, *inputs):
         # them have the same dtype
         argoffset = 0
         for inner_seq, outer_seq in zip(
-            self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
+            self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
         ):
             check_broadcast(outer_seq, inner_seq)
             new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
@@ -977,6 +978,7 @@ def make_node(self, *inputs):
                 self.info.mit_mot_in_slices,
                 self.info.mit_mot_out_slices[: self.info.n_mit_mot],
                 self.outer_mitmot(inputs),
+                strict=True,
             )
         ):
             outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
@@ -1031,6 +1033,7 @@ def make_node(self, *inputs):
                 self.info.mit_sot_in_slices,
                 self.outer_mitsot(inputs),
                 self.inner_mitsot_outs(self.inner_outputs),
+                strict=True,
             )
         ):
             outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
@@ -1083,6 +1086,7 @@ def make_node(self, *inputs):
                 self.inner_sitsot(self.inner_inputs),
                 self.outer_sitsot(inputs),
                 self.inner_sitsot_outs(self.inner_outputs),
+                strict=True,
             )
         ):
             outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
@@ -1130,6 +1134,7 @@ def make_node(self, *inputs):
                 self.inner_shared(self.inner_inputs),
                 self.inner_shared_outs(self.inner_outputs),
                 self.outer_shared(inputs),
+                strict=True,
             )
         ):
             outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
@@ -1188,7 +1193,9 @@ def make_node(self, *inputs):
         # type of tensor as the output, it is always a scalar int.
         new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
         for inner_nonseq, _outer_nonseq in zip(
-            self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs)
+            self.inner_non_seqs(self.inner_inputs),
+            self.outer_non_seqs(inputs),
+            strict=True,
         ):
             outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
             new_inputs.append(outer_nonseq)
@@ -1271,7 +1278,9 @@ def __eq__(self, other):
         if len(self.inner_outputs) != len(other.inner_outputs):
             return False

-        for self_in, other_in in zip(self.inner_inputs, other.inner_inputs):
+        for self_in, other_in in zip(
+            self.inner_inputs, other.inner_inputs, strict=True
+        ):
             if self_in.type != other_in.type:
                 return False

@@ -1406,7 +1415,7 @@ def prepare_fgraph(self, fgraph):
         fgraph.attach_feature(
             Supervisor(
                 inp
-                for spec, inp in zip(wrapped_inputs, fgraph.inputs)
+                for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
                 if not (
                     getattr(spec, "mutable", None)
                     or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
@@ -2086,7 +2095,9 @@ def perform(self, node, inputs, output_storage):
                 jout = j + offset_out
                 output_storage[j][0] = inner_output_storage[jout].storage[0]

-            pos = [(idx + 1) % store for idx, store in zip(pos, store_steps)]
+            pos = [
+                (idx + 1) % store for idx, store in zip(pos, store_steps, strict=True)
+            ]
             i = i + 1

         # 6. Check if you need to re-order output buffers
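Note: the rewritten list comprehension advances one write position per output buffer, wrapping each position modulo that buffer's length. A minimal sketch of this ring-buffer update, with `pos` and `store_steps` as illustrative stand-ins for Scan's internal state:

```python
pos = [0, 2]          # current write slot for each output buffer
store_steps = [1, 3]  # number of slots each buffer keeps

for _ in range(4):    # four mock scan iterations
    pos = [(idx + 1) % store for idx, store in zip(pos, store_steps, strict=True)]
    print(pos)
# [0, 0], [0, 1], [0, 2], [0, 0]: a 1-slot buffer always rewrites slot 0,
# while a 3-slot buffer cycles through slots 0, 1, 2.
```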
@@ -2171,7 +2182,7 @@ def perform(self, node, inputs, output_storage):

     def infer_shape(self, fgraph, node, input_shapes):
         # input_shapes correspond to the shapes of node.inputs
-        for inp, inp_shp in zip(node.inputs, input_shapes):
+        for inp, inp_shp in zip(node.inputs, input_shapes, strict=True):
             assert inp_shp is None or len(inp_shp) == inp.type.ndim

         # Here we build 2 variables;
@@ -2240,7 +2251,9 @@ def infer_shape(self, fgraph, node, input_shapes):
         # Non-sequences have a direct equivalent from self.inner_inputs in
         # node.inputs
         inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
-        for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
+        for in_ns, out_ns in zip(
+            inner_non_sequences, node.inputs[offset:], strict=True
+        ):
             out_equivalent[in_ns] = out_ns

         if info.as_while:
@@ -2275,7 +2288,7 @@ def infer_shape(self, fgraph, node, input_shapes):
                 r = node.outputs[n_outs + x]
                 assert r.ndim == 1 + len(out_shape_x)
                 shp = [node.inputs[offset + info.n_shared_outs + x]]
-                for i, shp_i in zip(range(1, r.ndim), out_shape_x):
+                for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
                     # Validate shp_i. v_shape_i is either None (if invalid),
                     # or a (variable, Boolean) tuple. The Boolean indicates
                     # whether variable is shp_i (if True), or an valid
@@ -2297,7 +2310,7 @@ def infer_shape(self, fgraph, node, input_shapes):
         if info.as_while:
             scan_outs_init = scan_outs
             scan_outs = []
-            for o, x in zip(node.outputs, scan_outs_init):
+            for o, x in zip(node.outputs, scan_outs_init, strict=True):
                 if x is None:
                     scan_outs.append(None)
                 else:
@@ -2573,7 +2586,9 @@ def compute_all_gradients(known_grads):
                 dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
             else:
                 disconnected_dC_dinps_t[dx] = False
-                for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts):
+                for Xt, Xt_placeholder in zip(
+                    diff_outputs[info.n_mit_mot_outs :], Xts, strict=True
+                ):
                     tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
                     dC_dinps_t[dx] = tmp

@@ -2653,7 +2668,9 @@ def compute_all_gradients(known_grads):
                 n = n_steps.tag.test_value
             else:
                 n = inputs[0].tag.test_value
-            for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
+            for taps, x in zip(
+                info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
+            ):
                 mintap = np.min(taps)
                 if hasattr(x[::-1][:mintap], "test_value"):
                     assert x[::-1][:mintap].tag.test_value.shape[0] == n
@@ -2668,7 +2685,9 @@ def compute_all_gradients(known_grads):
                     assert x[::-1].tag.test_value.shape[0] == n
         outer_inp_seqs += [
             x[::-1][: np.min(taps)]
-            for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
+            for taps, x in zip(
+                info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
+            )
         ]
         outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
         outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
@@ -2999,6 +3018,7 @@ def compute_all_gradients(known_grads):
             zip(
                 outputs[offset : offset + info.n_seqs],
                 type_outs[offset : offset + info.n_seqs],
+                strict=True,
             )
         ):
             if t == "connected":
@@ -3028,7 +3048,7 @@ def compute_all_gradients(known_grads):
                 gradients.append(NullType(t)())

         end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-        for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
+        for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)):
             if t == "connected":
                 # If the forward scan is in as_while mode, we need to pad
                 # the gradients, so that they match the size of the input
@@ -3063,7 +3083,7 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_shared_outs):
             disconnected = True
             connected_flags = self.connection_pattern(node)[idx + start]
-            for dC_dout, connected in zip(dC_douts, connected_flags):
+            for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
                 if not isinstance(dC_dout.type, DisconnectedType) and connected:
                     disconnected = False
             if disconnected:
@@ -3080,7 +3100,9 @@ def compute_all_gradients(known_grads):
         begin = end

         end = begin + n_sitsot_outs
-        for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
+        for p, (x, t) in enumerate(
+            zip(outputs[begin:end], type_outs[begin:end], strict=True)
+        ):
             if t == "connected":
                 gradients.append(x[-1])
             elif t == "disconnected":
@@ -3157,7 +3179,7 @@ def R_op(self, inputs, eval_points):
         e = 1 + info.n_seqs
         ie = info.n_seqs
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3172,7 +3194,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + int(sum(len(x) for x in info.mit_mot_in_slices))
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3187,7 +3209,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + int(sum(len(x) for x in info.mit_sot_in_slices))
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3202,7 +3224,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + info.n_sit_sot
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3226,7 +3248,7 @@ def R_op(self, inputs, eval_points):

         # All other arguments
         clean_eval_points = []
-        for inp, evp in zip(inputs[e:], eval_points[e:]):
+        for inp, evp in zip(inputs[e:], eval_points[e:], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else: