@@ -309,8 +309,8 @@ def log_details(self) -> None:
309309
310310 def reorder_loops_by_dep_pair (
311311 self , self_dep : MemoryDep , other_dep : MemoryDep
312- ) -> None :
313- return
312+ ) -> bool :
313+ return False
314314
315315 def update_mutated_names (self , renames : dict [str , str ]) -> None :
316316 self .mutation_renames = {
@@ -1130,6 +1130,11 @@ def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
11301130
11311131
11321132class SchedulerNode (BaseSchedulerNode ):
1133+ """
1134+ A SchedulerNode is a node for scheduling that encapsulates either
1135+ a ComputedBuffer or a TemplateBuffer.
1136+ """
1137+
11331138 _sizes : tuple [Sequence [sympy .Expr ], ...]
11341139 _body : LoopBody
11351140
@@ -1254,7 +1259,7 @@ def merge_loops(self) -> None:
12541259
12551260 def reorder_loops_by_dep_pair (
12561261 self , self_dep : MemoryDep , other_dep : MemoryDep
1257- ) -> None :
1262+ ) -> bool :
12581263 new_order = None
12591264 self_sizes = self ._sizes [0 ]
12601265 if len (self_sizes ) == self_dep .num_vars == other_dep .num_vars :
@@ -1266,11 +1271,13 @@ def reorder_loops_by_dep_pair(
12661271 "Reorder loops for %s with order %s" , self .get_name (), new_order
12671272 )
12681273 self .apply_new_loop_order (new_order )
1274+ return True
12691275 else :
12701276 loop_ordering_log .debug (
12711277 "Don't reordering %s because we can not decide the suitable loop order" ,
12721278 self .get_name (),
12731279 )
1280+ return False
12741281
12751282 def debug_str_extra (self ) -> str :
12761283 name = self .get_name ()
@@ -1527,18 +1534,21 @@ def estimate_flops(self) -> int | None:
15271534
15281535 def reorder_loops_by_dep_pair (
15291536 self , self_dep : MemoryDep , other_dep : MemoryDep
1530- ) -> None :
1537+ ) -> bool :
1538+ """
 1539+ Return True if a loop reordering is performed.
1540+ """
15311541 if self .is_template ():
15321542 # We can not really reorder loops for a triton template
1533- return
1543+ return False
15341544 self_sizes = None
15351545 for snode in self .snodes :
15361546 assert isinstance (snode , SchedulerNode )
15371547 if self_sizes is not None and tuple (self_sizes ) != tuple (snode ._sizes [0 ]):
15381548 loop_ordering_log .debug (
15391549 "Can not reorder fused node due to different sizes"
15401550 )
1541- return
1551+ return False
15421552 self_sizes = snode ._sizes [0 ]
15431553 new_order = None
15441554
@@ -1551,7 +1561,7 @@ def reorder_loops_by_dep_pair(
15511561 "Dont reordering fused node %s because we can not decide the suitable loop order" ,
15521562 self .get_name (),
15531563 )
1554- return
1564+ return False
15551565 metrics .num_loop_reordering += 1
15561566 loop_ordering_log .debug (
15571567 "Reorder loops for fused node %s with order %s" , self .get_name (), new_order
@@ -1561,6 +1571,7 @@ def reorder_loops_by_dep_pair(
15611571 snode .apply_new_loop_order (new_order )
15621572
15631573 refresh_group_node_dependencies (self )
1574+ return True
15641575
15651576 def __init__ (self , scheduler : Scheduler , snodes : list [BaseSchedulerNode ]) -> None :
15661577 super ().__init__ (scheduler )
@@ -3900,21 +3911,26 @@ def shared_data_after_reordering_loop(
39003911 Right now just greedily reorder the loop of node1 to be compatible with node2,
39013912 but ideally we should have some heuristics to reorder the loop for node2
39023913 to be compatible with node1 if that's more efficient.
3914+
 3915+ Return the amount of shared data re-computed by this method after the
 3916+ loop reordering. If no reordering (and hence no recomputation) happens,
 3917+ return -1 rather than 0, since 0 is a valid amount of shared data.
3918+
39033919 """
39043920
39053921 # TODO Don't do loop reordering for CPU for now.
39063922 # Should debug more why it does not work for CPU codegen
39073923 if not config .loop_ordering_after_fusion or any (
39083924 n .is_cpu () for n in [node1 , node2 ]
39093925 ):
3910- return 0
3926+ return - 1
39113927
39123928 node1_buffer_names = node1 .read_writes .buffer_names ()
39133929 node2_buffer_names = node2 .read_writes .buffer_names ()
39143930 # Fast path: no common buffers.
39153931 common_buffer_names = node1_buffer_names & node2_buffer_names
39163932 if not common_buffer_names :
3917- return 0
3933+ return - 1
39183934
39193935 node1_name2dep = {dep .name : dep for dep in node1 .read_writes .reads_and_writes ()}
39203936 node2_name2dep = {dep .name : dep for dep in node2 .read_writes .reads_and_writes ()}
@@ -3937,13 +3953,13 @@ def shared_data_after_reordering_loop(
39373953 )
39383954
39393955 if len (candidates ) == 0 :
3940- return 0
3956+ return - 1
39413957
39423958 # Pick the largest buffer to guide the loop reordering
39433959 _numel , lhs_dep , rhs_dep = max (candidates , key = operator .itemgetter (0 ))
39443960
39453961 if not isinstance (lhs_dep , MemoryDep ) or not isinstance (rhs_dep , MemoryDep ):
3946- return 0
3962+ return - 1
39473963
39483964 if lhs_dep .num_vars != rhs_dep .num_vars :
39493965 # this can happen due to we don't merge loops.
@@ -3952,21 +3968,22 @@ def shared_data_after_reordering_loop(
39523968 # normalization (merging loops)
39533969 if lhs_dep .normalize () == rhs_dep .normalize ():
39543970 return self .dep_size_hint (lhs_dep )
3955- return 0
3971+ return - 1
39563972
3973+ reordered = False
39573974 # Only reorder loops for pointwise for now
39583975 if not node1 .is_reduction ():
3959- node1 .reorder_loops_by_dep_pair (lhs_dep , rhs_dep )
3976+ reordered = node1 .reorder_loops_by_dep_pair (lhs_dep , rhs_dep )
39603977 elif not node2 .is_reduction ():
3961- node2 .reorder_loops_by_dep_pair (rhs_dep , lhs_dep )
3978+ reordered = node2 .reorder_loops_by_dep_pair (rhs_dep , lhs_dep )
39623979 else :
39633980 loop_ordering_log .debug (
39643981 "Don't reorder loops since both nodes are reductions: %s v.s. %s" ,
39653982 node1 .get_name (),
39663983 node2 .get_name (),
39673984 )
39683985
3969- return self .score_fusion_memory (node1 , node2 )
3986+ return self .score_fusion_memory (node1 , node2 ) if reordered else - 1
39703987
39713988 def unfusable_node (self , node : BaseSchedulerNode ) -> bool :
39723989 """
@@ -4255,7 +4272,9 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
42554272 shared_data_score < config .score_fusion_memory_threshold
42564273 and config .loop_ordering_after_fusion
42574274 ):
4258- shared_data_score = self .shared_data_after_reordering_loop (node1 , node2 )
4275+ new_shared_data_score = self .shared_data_after_reordering_loop (node1 , node2 )
4276+ if new_shared_data_score >= 0 :
4277+ shared_data_score = new_shared_data_score
42594278
42604279 if config .expand_dimension_for_pointwise_nodes and (
42614280 expand_analysis := self .get_expand_dim_for_pointwise_nodes (node1 , node2 )
0 commit comments