@@ -88,7 +88,7 @@ def __getstate__(self):
 
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
@@ -106,7 +106,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
             raise ValueError(
                 f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}"
             )
-        for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)):
+        for i, (core_out, sig) in enumerate(
+            zip(core_node.outputs, self.outputs_sig, strict=True)
+        ):
             if core_out.type.ndim != len(sig):
                 raise ValueError(
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
@@ -120,12 +122,13 @@ def make_node(self, *inputs):
         core_node = self._create_dummy_core_node(inputs)
 
         batch_ndims = max(
-            inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
+            inp.type.ndim - len(sig)
+            for inp, sig in zip(inputs, self.inputs_sig, strict=True)
         )
 
         batched_inputs = []
         batch_shapes = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             # Append missing dims to the left
             missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig))
             if missing_batch_ndims:
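
For readers unfamiliar with the surrounding make_node logic, a hedged sketch of the batch_ndims computation shown above (signature and ndims are hypothetical): each input may carry extra leading batch dimensions beyond its core dimensions, and the node batches over the maximum.

    # Hypothetical data for the gufunc signature "(m,k),(k,n)->(m,n)":
    inputs_sig = [("m", "k"), ("k", "n")]
    input_ndims = [4, 2]  # e.g. a (5, 3, m, k) tensor and a plain (k, n) matrix

    batch_ndims = max(
        ndim - len(sig) for ndim, sig in zip(input_ndims, inputs_sig, strict=True)
    )
    print(batch_ndims)  # 2 -> the (k, n) input gains two broadcastable dims on the left
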
@@ -141,7 +144,7 @@ def make_node(self, *inputs):
             batch_shape = tuple(
                 [
                     broadcast_static_dim_lengths(batch_dims)
-                    for batch_dims in zip(*batch_shapes)
+                    for batch_dims in zip(*batch_shapes, strict=True)
                 ]
             )
         except ValueError:
@@ -168,10 +171,10 @@ def infer_shape(
         batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes]
-        for input_shape, sig in zip(input_shapes, self.inputs_sig):
+        for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True):
             core_shape = input_shape[batch_ndims:]
 
-            for core_dim, dim_name in zip(core_shape, sig):
+            for core_dim, dim_name in zip(core_shape, sig, strict=True):
                 prev_core_dim = core_dims.get(core_dim)
                 if prev_core_dim is None:
                     core_dims[dim_name] = core_dim
@@ -182,7 +185,7 @@ def infer_shape(
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
 
         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig):
+        for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
@@ -213,17 +216,17 @@ def as_core(t, core_t):
         with config.change_flags(compute_test_value="off"):
             safe_inputs = [
                 tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig)
+                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
             ]
             core_node = self._create_dummy_core_node(safe_inputs)
 
             core_inputs = [
                 as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs)
+                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
             ]
             core_ograds = [
                 as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs)
+                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
             core_outputs = core_node.outputs
 
@@ -232,7 +235,11 @@ def as_core(t, core_t):
         igrads = vectorize_graph(
             [core_igrad for core_igrad in core_igrads if core_igrad is not None],
             replace=dict(
-                zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+                zip(
+                    core_inputs + core_outputs + core_ograds,
+                    inputs + outputs + ograds,
+                    strict=True,
+                )
             ),
         )
 
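
The replace mapping built above is a case where strict=True pays off: dict(zip(keys, values)) would otherwise silently drop entries if the concatenated lists ever disagreed in length, corrupting the graph substitution. A sketch with made-up names, not the PR's code:

    core_vars = ["a", "b", "c"]
    batched_vars = ["A", "B"]  # one replacement missing

    print(dict(zip(core_vars, batched_vars)))  # {'a': 'A', 'b': 'B'} -- 'c' silently unmapped
    dict(zip(core_vars, batched_vars, strict=True))  # raises ValueError instead
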
@@ -258,7 +265,7 @@ def L_op(self, inputs, outs, ograds):
         # the return value obviously zero so that gradient.grad can tell
         # this op did the right thing.
         new_rval = []
-        for elem, inp in zip(rval, inputs):
+        for elem, inp in zip(rval, inputs, strict=True):
             if isinstance(elem.type, NullType | DisconnectedType):
                 new_rval.append(elem)
             else:
@@ -272,15 +279,17 @@ def L_op(self, inputs, outs, ograds):
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if isinstance(rval[i].type, NullType | DisconnectedType):
                 continue
 
             assert inp.type.ndim == batch_ndims + len(sig)
 
             to_sum = [
                 j
-                for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape))
+                for j, (inp_s, out_s) in enumerate(
+                    zip(inp.type.shape, batch_shape, strict=False)
+                )
                 if inp_s == 1 and out_s != 1
             ]
             if to_sum:
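
Note the single strict=False above: it appears deliberate. inp.type.shape has batch_ndims + len(sig) entries while batch_shape has only batch_ndims, so the two iterables legitimately differ in length and the zip should stop after the batch dimensions. A sketch with hypothetical shapes:

    inp_shape = (1, 7, 3, 4)  # two batch dims followed by two core dims
    batch_shape = (5, 7)      # output batch dims only

    # Only the leading batch dims are paired; core dims (3, 4) are never compared:
    print(list(zip(inp_shape, batch_shape, strict=False)))  # [(1, 5), (7, 7)]
    # Dim 0 was broadcast from 1 to 5, so the gradient must be summed over it.
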
@@ -320,9 +329,14 @@ def _check_runtime_broadcast(self, node, inputs):
 
         for dims_and_bcast in zip(
             *[
-                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
-                for input, sinput in zip(inputs, node.inputs)
-            ]
+                zip(
+                    input.shape[:batch_ndim],
+                    sinput.type.broadcastable[:batch_ndim],
+                    strict=True,
+                )
+                for input, sinput in zip(inputs, node.inputs, strict=True)
+            ],
+            strict=True,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
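
To unpack the nested zip in _check_runtime_broadcast (a sketch with made-up shapes): the inner zips pair each input's runtime batch dims with its static broadcastable flags, and the outer zip(*...) transposes the result so each iteration inspects one batch axis across all inputs. A runtime dim of 1 that was not declared broadcastable, next to a non-1 dim in another input, is the invalid runtime broadcast being rejected:

    batch_ndim = 2
    shapes = [(5, 1), (5, 7)]                 # runtime batch shapes of two inputs
    bcast = [(False, False), (False, False)]  # static broadcastable flags

    for dims_and_bcast in zip(
        *[
            zip(s[:batch_ndim], b[:batch_ndim], strict=True)
            for s, b in zip(shapes, bcast, strict=True)
        ],
        strict=True,
    ):
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            print("runtime broadcasting detected:", dims_and_bcast)
            # -> runtime broadcasting detected: ((1, False), (7, False))
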
@@ -343,7 +357,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
             res = (res,)
 
-        for node_out, out_storage, r in zip(node.outputs, output_storage, res):
+        for node_out, out_storage, r in zip(
+            node.outputs, output_storage, res, strict=True
+        ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
                 r = np.asarray(r, dtype=out_dtype)