@@ -399,14 +399,14 @@ def print_profile(cls, stream, prof, level=0):
399
399
file = stream ,
400
400
)
401
401
ll = []
402
- for rewrite , nb_n in zip (rewrites , nb_nodes ):
402
+ for rewrite , nb_n in zip (rewrites , nb_nodes , strict = True ):
403
403
if hasattr (rewrite , "__name__" ):
404
404
name = rewrite .__name__
405
405
else :
406
406
name = rewrite .name
407
407
idx = rewrites .index (rewrite )
408
408
ll .append ((name , rewrite .__class__ .__name__ , idx , * nb_n ))
409
- lll = sorted (zip (prof , ll ), key = lambda a : a [0 ])
409
+ lll = sorted (zip (prof , ll , strict = True ), key = lambda a : a [0 ])
410
410
411
411
for t , rewrite in lll [::- 1 ]:
412
412
i = rewrite [2 ]
@@ -480,7 +480,7 @@ def merge_profile(prof1, prof2):
480
480
481
481
new_rewrite = SequentialGraphRewriter (* new_l )
482
482
new_nb_nodes = []
483
- for p1 , p2 in zip (prof1 [8 ], prof2 [8 ]):
483
+ for p1 , p2 in zip (prof1 [8 ], prof2 [8 ], strict = True ):
484
484
new_nb_nodes .append ((p1 [0 ] + p2 [0 ], p1 [1 ] + p2 [1 ]))
485
485
new_nb_nodes .extend (prof1 [8 ][len (new_nb_nodes ) :])
486
486
new_nb_nodes .extend (prof2 [8 ][len (new_nb_nodes ) :])
@@ -635,7 +635,7 @@ def process_node(self, fgraph, node):
635
635
636
636
inputs_match = all (
637
637
node_in is cand_in
638
- for node_in , cand_in in zip (node .inputs , candidate .inputs )
638
+ for node_in , cand_in in zip (node .inputs , candidate .inputs , strict = True )
639
639
)
640
640
641
641
if inputs_match and node .op == candidate .op :
@@ -649,6 +649,7 @@ def process_node(self, fgraph, node):
649
649
node .outputs ,
650
650
candidate .outputs ,
651
651
["merge" ] * len (node .outputs ),
652
+ strict = True ,
652
653
)
653
654
)
654
655
@@ -721,7 +722,9 @@ def apply(self, fgraph):
721
722
inputs_match = all (
722
723
node_in is cand_in
723
724
for node_in , cand_in in zip (
724
- var .owner .inputs , candidate_var .owner .inputs
725
+ var .owner .inputs ,
726
+ candidate_var .owner .inputs ,
727
+ strict = True ,
725
728
)
726
729
)
727
730
@@ -1440,7 +1443,7 @@ def transform(self, fgraph, node):
1440
1443
repl = self .op2 .make_node (* node .inputs )
1441
1444
if self .transfer_tags :
1442
1445
repl .tag = copy .copy (node .tag )
1443
- for output , new_output in zip (node .outputs , repl .outputs ):
1446
+ for output , new_output in zip (node .outputs , repl .outputs , strict = True ):
1444
1447
new_output .tag = copy .copy (output .tag )
1445
1448
return repl .outputs
1446
1449
@@ -1622,7 +1625,7 @@ def transform(self, fgraph, node, get_nodes=True):
1622
1625
continue
1623
1626
ret = self .transform (fgraph , real_node , get_nodes = False )
1624
1627
if ret is not False and ret is not None :
1625
- return dict (zip (real_node .outputs , ret ))
1628
+ return dict (zip (real_node .outputs , ret , strict = True ))
1626
1629
1627
1630
if node .op != self .op :
1628
1631
return False
@@ -1654,7 +1657,7 @@ def transform(self, fgraph, node, get_nodes=True):
1654
1657
len (node .outputs ) == len (ret .owner .outputs )
1655
1658
and all (
1656
1659
o .type .is_super (new_o .type )
1657
- for o , new_o in zip (node .outputs , ret .owner .outputs )
1660
+ for o , new_o in zip (node .outputs , ret .owner .outputs , strict = True )
1658
1661
)
1659
1662
):
1660
1663
return False
@@ -1946,7 +1949,7 @@ def process_node(
1946
1949
)
1947
1950
# None in the replacement mean that this variable isn't used
1948
1951
# and we want to remove it
1949
- for r , rnew in zip (old_vars , replacements ):
1952
+ for r , rnew in zip (old_vars , replacements , strict = True ):
1950
1953
if rnew is None and len (fgraph .clients [r ]) > 0 :
1951
1954
raise ValueError (
1952
1955
f"Node rewriter { node_rewriter } tried to remove a variable"
@@ -1956,7 +1959,7 @@ def process_node(
1956
1959
# the replacement
1957
1960
repl_pairs = [
1958
1961
(r , rnew )
1959
- for r , rnew in zip (old_vars , replacements )
1962
+ for r , rnew in zip (old_vars , replacements , strict = True )
1960
1963
if rnew is not r and rnew is not None
1961
1964
]
1962
1965
@@ -2651,17 +2654,23 @@ def print_profile(cls, stream, prof, level=0):
2651
2654
print (blanc , "Global, final, and clean up rewriters" , file = stream )
2652
2655
for i in range (len (loop_timing )):
2653
2656
print (blanc , f"Iter { int (i )} " , file = stream )
2654
- for o , prof in zip (rewrite .global_rewriters , global_sub_profs [i ]):
2657
+ for o , prof in zip (
2658
+ rewrite .global_rewriters , global_sub_profs [i ], strict = True
2659
+ ):
2655
2660
try :
2656
2661
o .print_profile (stream , prof , level + 2 )
2657
2662
except NotImplementedError :
2658
2663
print (blanc , "merge not implemented for " , o )
2659
- for o , prof in zip (rewrite .final_rewriters , final_sub_profs [i ]):
2664
+ for o , prof in zip (
2665
+ rewrite .final_rewriters , final_sub_profs [i ], strict = True
2666
+ ):
2660
2667
try :
2661
2668
o .print_profile (stream , prof , level + 2 )
2662
2669
except NotImplementedError :
2663
2670
print (blanc , "merge not implemented for " , o )
2664
- for o , prof in zip (rewrite .cleanup_rewriters , cleanup_sub_profs [i ]):
2671
+ for o , prof in zip (
2672
+ rewrite .cleanup_rewriters , cleanup_sub_profs [i ], strict = True
2673
+ ):
2665
2674
try :
2666
2675
o .print_profile (stream , prof , level + 2 )
2667
2676
except NotImplementedError :
@@ -2879,7 +2888,7 @@ def local_recursive_function(
2879
2888
outs , rewritten_vars = local_recursive_function (
2880
2889
rewrite_list , inp , rewritten_vars , depth + 1
2881
2890
)
2882
- for k , v in zip (inp .owner .outputs , outs ):
2891
+ for k , v in zip (inp .owner .outputs , outs , strict = True ):
2883
2892
rewritten_vars [k ] = v
2884
2893
nw_in = outs [inp .owner .outputs .index (inp )]
2885
2894
@@ -2897,7 +2906,7 @@ def local_recursive_function(
2897
2906
if ret is not False and ret is not None :
2898
2907
assert isinstance (ret , Sequence )
2899
2908
assert len (ret ) == len (node .outputs ), rewrite
2900
- for k , v in zip (node .outputs , ret ):
2909
+ for k , v in zip (node .outputs , ret , strict = True ):
2901
2910
rewritten_vars [k ] = v
2902
2911
results = ret
2903
2912
if ret [0 ].owner :
0 commit comments