@@ -385,11 +385,11 @@ def make_all(
385
385
f ,
386
386
[
387
387
Container (input , storage )
388
- for input , storage in zip (fgraph .inputs , input_storage )
388
+ for input , storage in zip (fgraph .inputs , input_storage , strict = True )
389
389
],
390
390
[
391
391
Container (output , storage , readonly = True )
392
- for output , storage in zip (fgraph .outputs , output_storage )
392
+ for output , storage in zip (fgraph .outputs , output_storage , strict = True )
393
393
],
394
394
thunks ,
395
395
order ,
@@ -509,7 +509,9 @@ def make_thunk(self, **kwargs):
509
509
kwargs .pop ("input_storage" , None )
510
510
make_all += [x .make_all (** kwargs ) for x in self .linkers [1 :]]
511
511
512
- fns , input_lists , output_lists , thunk_lists , order_lists = zip (* make_all )
512
+ fns , input_lists , output_lists , thunk_lists , order_lists = zip (
513
+ * make_all , strict = True
514
+ )
513
515
514
516
order_list0 = order_lists [0 ]
515
517
for order_list in order_lists [1 :]:
@@ -521,12 +523,12 @@ def make_thunk(self, **kwargs):
521
523
inputs0 = input_lists [0 ]
522
524
outputs0 = output_lists [0 ]
523
525
524
- thunk_groups = list (zip (* thunk_lists ))
525
- order = [x [0 ] for x in zip (* order_lists )]
526
+ thunk_groups = list (zip (* thunk_lists , strict = True ))
527
+ order = [x [0 ] for x in zip (* order_lists , strict = True )]
526
528
527
529
to_reset = [
528
530
thunk .outputs [j ]
529
- for thunks , node in zip (thunk_groups , order )
531
+ for thunks , node in zip (thunk_groups , order , strict = True )
530
532
for j , output in enumerate (node .outputs )
531
533
if output in no_recycling
532
534
for thunk in thunks
@@ -537,12 +539,12 @@ def make_thunk(self, **kwargs):
537
539
538
540
def f ():
539
541
for inputs in input_lists [1 :]:
540
- for input1 , input2 in zip (inputs0 , inputs ):
542
+ for input1 , input2 in zip (inputs0 , inputs , strict = True ):
541
543
input2 .storage [0 ] = copy (input1 .storage [0 ])
542
544
for x in to_reset :
543
545
x [0 ] = None
544
546
pre (self , [input .data for input in input_lists [0 ]], order , thunk_groups )
545
- for i , (thunks , node ) in enumerate (zip (thunk_groups , order )):
547
+ for i , (thunks , node ) in enumerate (zip (thunk_groups , order , strict = True )):
546
548
try :
547
549
wrapper (self .fgraph , i , node , * thunks )
548
550
except Exception :
@@ -664,7 +666,9 @@ def thunk(
664
666
):
665
667
outputs = fgraph_jit (* [self .input_filter (x [0 ]) for x in thunk_inputs ])
666
668
667
- for o_var , o_storage , o_val in zip (fgraph .outputs , thunk_outputs , outputs ):
669
+ for o_var , o_storage , o_val in zip (
670
+ fgraph .outputs , thunk_outputs , outputs , strict = True
671
+ ):
668
672
compute_map [o_var ][0 ] = True
669
673
o_storage [0 ] = self .output_filter (o_var , o_val )
670
674
return outputs
@@ -730,11 +734,11 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
730
734
fn ,
731
735
[
732
736
Container (input , storage )
733
- for input , storage in zip (fgraph .inputs , input_storage )
737
+ for input , storage in zip (fgraph .inputs , input_storage , strict = True )
734
738
],
735
739
[
736
740
Container (output , storage , readonly = True )
737
- for output , storage in zip (fgraph .outputs , output_storage )
741
+ for output , storage in zip (fgraph .outputs , output_storage , strict = True )
738
742
],
739
743
thunks ,
740
744
nodes ,
0 commit comments