@@ -2714,12 +2714,22 @@ def _init(self, nodes: list[ir.Operation]) -> None:
         if (
             used_non_deterministic_runtime_estimations()
             and config_comms.runtime_estimations_align_across_all_distributed_ranks
-        ):
-            from .comms import (
-                align_runtime_estimations_across_all_distributed_ranks,
+            and (
+                config.runtime_estimations_mms_benchmark
+                or config_comms.runtime_estimations_use_nccl_lib_estimations
             )
+        ):
+            has_collectives = False
+            for node in self.nodes:
+                if is_collective(node.node):
+                    has_collectives = True
+                    break
+            if has_collectives:
+                from .comms import (
+                    align_runtime_estimations_across_all_distributed_ranks,
+                )
 
-            align_runtime_estimations_across_all_distributed_ranks(self.nodes)
+                align_runtime_estimations_across_all_distributed_ranks(self.nodes)
 
         from torch._logging import trace_structured
 
@@ -2742,8 +2752,11 @@ def _init(self, nodes: list[ir.Operation]) -> None:
         self.process_grouped_nodes()
 
         if (
+            # pyrefly: ignore[unbound-name]
             config.graph_partition
+            # pyrefly: ignore[unbound-name]
             and config.triton.cudagraphs
+            # pyrefly: ignore[unbound-name]
             and config.triton.reorder_for_reducing_graph_partitions
         ):
             self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
@@ -2755,6 +2768,7 @@ def _init(self, nodes: list[ir.Operation]) -> None:
         self.insert_memory_check_nodes()
 
         log_ir_post_fusion(self.nodes)
+        # pyrefly: ignore[unbound-name]
         V.debug.graph_diagram(self.nodes)
         self.debug_draw_graph()
 
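To make the intent of the first hunk concrete, below is a hedged, self-contained sketch of the gating it introduces: cross-rank alignment of runtime estimations is skipped unless an estimator that can disagree across ranks is enabled and the graph actually contains a collective. FakeNode, should_align_across_ranks, and the boolean flags are illustrative stand-ins, not Inductor's real scheduler API.

    from dataclasses import dataclass


    # Illustrative stand-in for a scheduler node; only the property the gate
    # inspects is modeled here.
    @dataclass
    class FakeNode:
        is_collective: bool


    def should_align_across_ranks(
        nodes: list[FakeNode],
        mms_benchmark: bool,
        nccl_lib_estimations: bool,
    ) -> bool:
        # Alignment is only needed when an estimator that can disagree across
        # ranks is enabled ...
        if not (mms_benchmark or nccl_lib_estimations):
            return False
        # ... and when the graph actually contains at least one collective op.
        return any(node.is_collective for node in nodes)


    # A graph without collectives never triggers cross-rank alignment.
    assert not should_align_across_ranks([FakeNode(False)], True, True)
    assert should_align_across_ranks([FakeNode(False), FakeNode(True)], False, True)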