@@ -88,7 +88,7 @@ def __getstate__(self):
 
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
@@ -106,7 +106,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
             raise ValueError(
                 f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}"
             )
-        for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)):
+        for i, (core_out, sig) in enumerate(
+            zip(core_node.outputs, self.outputs_sig, strict=True)
+        ):
             if core_out.type.ndim != len(sig):
                 raise ValueError(
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
@@ -120,12 +122,13 @@ def make_node(self, *inputs):
         core_node = self._create_dummy_core_node(inputs)
 
         batch_ndims = max(
-            inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
+            inp.type.ndim - len(sig)
+            for inp, sig in zip(inputs, self.inputs_sig, strict=True)
         )
 
         batched_inputs = []
         batch_shapes = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             # Append missing dims to the left
             missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig))
             if missing_batch_ndims:
@@ -141,7 +144,7 @@ def make_node(self, *inputs):
             batch_shape = tuple(
                 [
                     broadcast_static_dim_lengths(batch_dims)
-                    for batch_dims in zip(*batch_shapes)
+                    for batch_dims in zip(*batch_shapes, strict=True)
                 ]
             )
         except ValueError:
@@ -168,11 +171,11 @@ def infer_shape(
         batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = []
-        for input_shape, sig in zip(input_shapes, self.inputs_sig):
+        for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True):
             batch_shapes.append(input_shape[:batch_ndims])
             core_shape = input_shape[batch_ndims:]
 
-            for core_dim, dim_name in zip(core_shape, sig):
+            for core_dim, dim_name in zip(core_shape, sig, strict=True):
                 prev_core_dim = core_dims.get(core_dim)
                 if prev_core_dim is None:
                     core_dims[dim_name] = core_dim
@@ -183,7 +186,7 @@ def infer_shape(
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
 
         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig):
+        for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
@@ -214,17 +217,17 @@ def as_core(t, core_t):
         with config.change_flags(compute_test_value="off"):
             safe_inputs = [
                 tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig)
+                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
             ]
             core_node = self._create_dummy_core_node(safe_inputs)
 
             core_inputs = [
                 as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs)
+                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
             ]
             core_ograds = [
                 as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs)
+                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
             core_outputs = core_node.outputs
 
@@ -233,7 +236,11 @@ def as_core(t, core_t):
             igrads = vectorize_graph(
                 [core_igrad for core_igrad in core_igrads if core_igrad is not None],
                 replace=dict(
-                    zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+                    zip(
+                        core_inputs + core_outputs + core_ograds,
+                        inputs + outputs + ograds,
+                        strict=True,
+                    )
                 ),
             )
 
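The hunk above builds the `replace` mapping for `vectorize_graph` from three concatenated parallel lists; with `strict=True`, a length mismatch now fails at graph-construction time rather than producing a mapping with silently missing entries. A toy sketch in plain Python, with hypothetical variable names:

```python
# Toy sketch of the replace-mapping construction above (hypothetical names):
# dict(zip(...)) with strict=True fails loudly if the variable lists diverge,
# instead of building a mapping that quietly lacks some replacements.
core_vars = ["core_x", "core_out", "core_ograd"]
batched_vars = ["x", "out"]  # one entry short

try:
    replace = dict(zip(core_vars, batched_vars, strict=True))
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1
```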
@@ -259,7 +266,7 @@ def L_op(self, inputs, outs, ograds):
         # the return value obviously zero so that gradient.grad can tell
         # this op did the right thing.
         new_rval = []
-        for elem, inp in zip(rval, inputs):
+        for elem, inp in zip(rval, inputs, strict=True):
             if isinstance(elem.type, NullType | DisconnectedType):
                 new_rval.append(elem)
             else:
@@ -273,15 +280,17 @@ def L_op(self, inputs, outs, ograds):
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if isinstance(rval[i].type, NullType | DisconnectedType):
                 continue
 
             assert inp.type.ndim == batch_ndims + len(sig)
 
             to_sum = [
                 j
-                for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape))
+                for j, (inp_s, out_s) in enumerate(
+                    zip(inp.type.shape, batch_shape, strict=True)
+                )
                 if inp_s == 1 and out_s != 1
             ]
             if to_sum:
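A worked toy example of the `to_sum` computation above (simplified to equal-length, hypothetical batch shapes): a batch dimension that is statically 1 on the input but larger on the output was broadcast, so the gradient must be summed back over that axis.

```python
# Toy worked example of the to_sum logic (hypothetical shapes): find the axes
# where the input's batch dim is 1 but the output's batch dim is larger.
inp_batch_shape = (1, 3)  # input was broadcast along axis 0
out_batch_shape = (5, 3)  # output batch shape

to_sum = [
    j
    for j, (inp_s, out_s) in enumerate(
        zip(inp_batch_shape, out_batch_shape, strict=True)
    )
    if inp_s == 1 and out_s != 1
]
print(to_sum)  # [0]: sum the gradient over axis 0 (keeping dims) before returning it
```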
@@ -321,9 +330,14 @@ def _check_runtime_broadcast(self, node, inputs):
 
         for dims_and_bcast in zip(
             *[
-                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
-                for input, sinput in zip(inputs, node.inputs)
-            ]
+                zip(
+                    input.shape[:batch_ndim],
+                    sinput.type.broadcastable[:batch_ndim],
+                    strict=True,
+                )
+                for input, sinput in zip(inputs, node.inputs, strict=True)
+            ],
+            strict=True,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
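The nested zips above pair each runtime input's batch dimensions with its static broadcastable flags, then `zip(*...)` transposes across inputs so the check inspects one tuple per batch dimension. A toy sketch with hypothetical shapes:

```python
# Toy sketch of the nested zips above (hypothetical shapes): inner zips pair
# each input's batch dims with its broadcastable flags; the outer zip(*...)
# transposes, yielding one (dim, broadcastable) pair per input for each
# batch dimension.
shapes = [(5, 3), (1, 3)]
bcast = [(False, False), (True, False)]

per_input = [
    zip(shape, flags, strict=True)
    for shape, flags in zip(shapes, bcast, strict=True)
]
for dims_and_bcast in zip(*per_input, strict=True):
    print(dims_and_bcast)
# ((5, False), (1, True))
# ((3, False), (3, False))
```

In the first tuple, the runtime dim 1 is paired with `broadcastable=True`, which is allowed; the check only raises when a runtime dim of 1 appears where the static type promised a non-broadcastable dimension, i.e. a `(1, False)` pair alongside dims larger than 1.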
@@ -344,7 +358,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
             res = (res,)
 
-        for node_out, out_storage, r in zip(node.outputs, output_storage, res):
+        for node_out, out_storage, r in zip(
+            node.outputs, output_storage, res, strict=True
+        ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
                 r = np.asarray(r, dtype=out_dtype)