@@ -399,14 +399,14 @@ def print_profile(cls, stream, prof, level=0):
399
399
file = stream ,
400
400
)
401
401
ll = []
402
- for rewrite , nb_n in zip (rewrites , nb_nodes ):
402
+ for rewrite , nb_n in zip (rewrites , nb_nodes , strict = True ):
403
403
if hasattr (rewrite , "__name__" ):
404
404
name = rewrite .__name__
405
405
else :
406
406
name = rewrite .name
407
407
idx = rewrites .index (rewrite )
408
408
ll .append ((name , rewrite .__class__ .__name__ , idx , * nb_n ))
409
- lll = sorted (zip (prof , ll ), key = lambda a : a [0 ])
409
+ lll = sorted (zip (prof , ll , strict = True ), key = lambda a : a [0 ])
410
410
411
411
for t , rewrite in lll [::- 1 ]:
412
412
i = rewrite [2 ]
@@ -480,7 +480,8 @@ def merge_profile(prof1, prof2):
480
480
481
481
new_rewrite = SequentialGraphRewriter (* new_l )
482
482
new_nb_nodes = [
483
- (p1 [0 ] + p2 [0 ], p1 [1 ] + p2 [1 ]) for p1 , p2 in zip (prof1 [8 ], prof2 [8 ])
483
+ (p1 [0 ] + p2 [0 ], p1 [1 ] + p2 [1 ])
484
+ for p1 , p2 in zip (prof1 [8 ], prof2 [8 ], strict = True )
484
485
]
485
486
new_nb_nodes .extend (prof1 [8 ][len (new_nb_nodes ) :])
486
487
new_nb_nodes .extend (prof2 [8 ][len (new_nb_nodes ) :])
@@ -635,7 +636,7 @@ def process_node(self, fgraph, node):
635
636
636
637
inputs_match = all (
637
638
node_in is cand_in
638
- for node_in , cand_in in zip (node .inputs , candidate .inputs )
639
+ for node_in , cand_in in zip (node .inputs , candidate .inputs , strict = True )
639
640
)
640
641
641
642
if inputs_match and node .op == candidate .op :
@@ -649,6 +650,7 @@ def process_node(self, fgraph, node):
649
650
node .outputs ,
650
651
candidate .outputs ,
651
652
["merge" ] * len (node .outputs ),
653
+ strict = True ,
652
654
)
653
655
)
654
656
@@ -721,7 +723,9 @@ def apply(self, fgraph):
721
723
inputs_match = all (
722
724
node_in is cand_in
723
725
for node_in , cand_in in zip (
724
- var .owner .inputs , candidate_var .owner .inputs
726
+ var .owner .inputs ,
727
+ candidate_var .owner .inputs ,
728
+ strict = True ,
725
729
)
726
730
)
727
731
@@ -1434,7 +1438,7 @@ def transform(self, fgraph, node):
1434
1438
repl = self .op2 .make_node (* node .inputs )
1435
1439
if self .transfer_tags :
1436
1440
repl .tag = copy .copy (node .tag )
1437
- for output , new_output in zip (node .outputs , repl .outputs ):
1441
+ for output , new_output in zip (node .outputs , repl .outputs , strict = True ):
1438
1442
new_output .tag = copy .copy (output .tag )
1439
1443
return repl .outputs
1440
1444
@@ -1616,7 +1620,7 @@ def transform(self, fgraph, node, get_nodes=True):
1616
1620
continue
1617
1621
ret = self .transform (fgraph , real_node , get_nodes = False )
1618
1622
if ret is not False and ret is not None :
1619
- return dict (zip (real_node .outputs , ret ))
1623
+ return dict (zip (real_node .outputs , ret , strict = True ))
1620
1624
1621
1625
if node .op != self .op :
1622
1626
return False
@@ -1648,7 +1652,7 @@ def transform(self, fgraph, node, get_nodes=True):
1648
1652
len (node .outputs ) == len (ret .owner .outputs )
1649
1653
and all (
1650
1654
o .type .is_super (new_o .type )
1651
- for o , new_o in zip (node .outputs , ret .owner .outputs )
1655
+ for o , new_o in zip (node .outputs , ret .owner .outputs , strict = True )
1652
1656
)
1653
1657
):
1654
1658
return False
@@ -1940,7 +1944,7 @@ def process_node(
1940
1944
)
1941
1945
# None in the replacement mean that this variable isn't used
1942
1946
# and we want to remove it
1943
- for r , rnew in zip (old_vars , replacements ):
1947
+ for r , rnew in zip (old_vars , replacements , strict = True ):
1944
1948
if rnew is None and len (fgraph .clients [r ]) > 0 :
1945
1949
raise ValueError (
1946
1950
f"Node rewriter { node_rewriter } tried to remove a variable"
@@ -1950,7 +1954,7 @@ def process_node(
1950
1954
# the replacement
1951
1955
repl_pairs = [
1952
1956
(r , rnew )
1953
- for r , rnew in zip (old_vars , replacements )
1957
+ for r , rnew in zip (old_vars , replacements , strict = True )
1954
1958
if rnew is not r and rnew is not None
1955
1959
]
1956
1960
@@ -2633,17 +2637,23 @@ def print_profile(cls, stream, prof, level=0):
2633
2637
print (blanc , "Global, final, and clean up rewriters" , file = stream )
2634
2638
for i in range (len (loop_timing )):
2635
2639
print (blanc , f"Iter { int (i )} " , file = stream )
2636
- for o , prof in zip (rewrite .global_rewriters , global_sub_profs [i ]):
2640
+ for o , prof in zip (
2641
+ rewrite .global_rewriters , global_sub_profs [i ], strict = True
2642
+ ):
2637
2643
try :
2638
2644
o .print_profile (stream , prof , level + 2 )
2639
2645
except NotImplementedError :
2640
2646
print (blanc , "merge not implemented for " , o )
2641
- for o , prof in zip (rewrite .final_rewriters , final_sub_profs [i ]):
2647
+ for o , prof in zip (
2648
+ rewrite .final_rewriters , final_sub_profs [i ], strict = True
2649
+ ):
2642
2650
try :
2643
2651
o .print_profile (stream , prof , level + 2 )
2644
2652
except NotImplementedError :
2645
2653
print (blanc , "merge not implemented for " , o )
2646
- for o , prof in zip (rewrite .cleanup_rewriters , cleanup_sub_profs [i ]):
2654
+ for o , prof in zip (
2655
+ rewrite .cleanup_rewriters , cleanup_sub_profs [i ], strict = True
2656
+ ):
2647
2657
try :
2648
2658
o .print_profile (stream , prof , level + 2 )
2649
2659
except NotImplementedError :
@@ -2861,7 +2871,7 @@ def local_recursive_function(
2861
2871
outs , rewritten_vars = local_recursive_function (
2862
2872
rewrite_list , inp , rewritten_vars , depth + 1
2863
2873
)
2864
- for k , v in zip (inp .owner .outputs , outs ):
2874
+ for k , v in zip (inp .owner .outputs , outs , strict = True ):
2865
2875
rewritten_vars [k ] = v
2866
2876
nw_in = outs [inp .owner .outputs .index (inp )]
2867
2877
@@ -2879,7 +2889,7 @@ def local_recursive_function(
2879
2889
if ret is not False and ret is not None :
2880
2890
assert isinstance (ret , Sequence )
2881
2891
assert len (ret ) == len (node .outputs ), rewrite
2882
- for k , v in zip (node .outputs , ret ):
2892
+ for k , v in zip (node .outputs , ret , strict = True ):
2883
2893
rewritten_vars [k ] = v
2884
2894
results = ret
2885
2895
if ret [0 ].owner :
0 commit comments