@@ -385,11 +385,11 @@ def make_all(
385
385
f ,
386
386
[
387
387
Container (input , storage )
388
- for input , storage in zip (fgraph .inputs , input_storage )
388
+ for input , storage in zip (fgraph .inputs , input_storage , strict = True )
389
389
],
390
390
[
391
391
Container (output , storage , readonly = True )
392
- for output , storage in zip (fgraph .outputs , output_storage )
392
+ for output , storage in zip (fgraph .outputs , output_storage , strict = True )
393
393
],
394
394
thunks ,
395
395
order ,
@@ -509,7 +509,9 @@ def make_thunk(self, **kwargs):
509
509
kwargs .pop ("input_storage" , None )
510
510
make_all += [x .make_all (** kwargs ) for x in self .linkers [1 :]]
511
511
512
- fns , input_lists , output_lists , thunk_lists , order_lists = zip (* make_all )
512
+ fns , input_lists , output_lists , thunk_lists , order_lists = zip (
513
+ * make_all , strict = True
514
+ )
513
515
514
516
order_list0 = order_lists [0 ]
515
517
for order_list in order_lists [1 :]:
@@ -521,11 +523,11 @@ def make_thunk(self, **kwargs):
521
523
inputs0 = input_lists [0 ]
522
524
outputs0 = output_lists [0 ]
523
525
524
- thunk_groups = list (zip (* thunk_lists ))
525
- order = [x [0 ] for x in zip (* order_lists )]
526
+ thunk_groups = list (zip (* thunk_lists , strict = True ))
527
+ order = [x [0 ] for x in zip (* order_lists , strict = True )]
526
528
527
529
to_reset = []
528
- for thunks , node in zip (thunk_groups , order ):
530
+ for thunks , node in zip (thunk_groups , order , strict = True ):
529
531
for j , output in enumerate (node .outputs ):
530
532
if output in no_recycling :
531
533
for thunk in thunks :
@@ -536,12 +538,12 @@ def make_thunk(self, **kwargs):
536
538
537
539
def f ():
538
540
for inputs in input_lists [1 :]:
539
- for input1 , input2 in zip (inputs0 , inputs ):
541
+ for input1 , input2 in zip (inputs0 , inputs , strict = True ):
540
542
input2 .storage [0 ] = copy (input1 .storage [0 ])
541
543
for x in to_reset :
542
544
x [0 ] = None
543
545
pre (self , [input .data for input in input_lists [0 ]], order , thunk_groups )
544
- for i , (thunks , node ) in enumerate (zip (thunk_groups , order )):
546
+ for i , (thunks , node ) in enumerate (zip (thunk_groups , order , strict = True )):
545
547
try :
546
548
wrapper (self .fgraph , i , node , * thunks )
547
549
except Exception :
@@ -663,7 +665,9 @@ def thunk(
663
665
):
664
666
outputs = fgraph_jit (* [self .input_filter (x [0 ]) for x in thunk_inputs ])
665
667
666
- for o_var , o_storage , o_val in zip (fgraph .outputs , thunk_outputs , outputs ):
668
+ for o_var , o_storage , o_val in zip (
669
+ fgraph .outputs , thunk_outputs , outputs , strict = True
670
+ ):
667
671
compute_map [o_var ][0 ] = True
668
672
o_storage [0 ] = self .output_filter (o_var , o_val )
669
673
return outputs
@@ -731,11 +735,11 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
731
735
fn ,
732
736
[
733
737
Container (input , storage )
734
- for input , storage in zip (fgraph .inputs , input_storage )
738
+ for input , storage in zip (fgraph .inputs , input_storage , strict = True )
735
739
],
736
740
[
737
741
Container (output , storage , readonly = True )
738
- for output , storage in zip (fgraph .outputs , output_storage )
742
+ for output , storage in zip (fgraph .outputs , output_storage , strict = True )
739
743
],
740
744
thunks ,
741
745
nodes ,
0 commit comments