Skip to content

Commit 6fa7791

Browse files
Microvepytorchmergebot
authored andcommitted
Reland"Fix different seq length (pytorch#167481)" (pytorch#168144)
Differential Revision: D87413883 Pull Request resolved: pytorch#168144 Approved by: https://github.com/eellison
1 parent c614128 commit 6fa7791

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

torch/_inductor/scheduler.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2714,12 +2714,22 @@ def _init(self, nodes: list[ir.Operation]) -> None:
27142714
if (
27152715
used_non_deterministic_runtime_estimations()
27162716
and config_comms.runtime_estimations_align_across_all_distributed_ranks
2717-
):
2718-
from .comms import (
2719-
align_runtime_estimations_across_all_distributed_ranks,
2717+
and (
2718+
config.runtime_estimations_mms_benchmark
2719+
or config_comms.runtime_estimations_use_nccl_lib_estimations
27202720
)
2721+
):
2722+
has_collectives = False
2723+
for node in self.nodes:
2724+
if is_collective(node.node):
2725+
has_collectives = True
2726+
break
2727+
if has_collectives:
2728+
from .comms import (
2729+
align_runtime_estimations_across_all_distributed_ranks,
2730+
)
27212731

2722-
align_runtime_estimations_across_all_distributed_ranks(self.nodes)
2732+
align_runtime_estimations_across_all_distributed_ranks(self.nodes)
27232733

27242734
from torch._logging import trace_structured
27252735

@@ -2742,8 +2752,11 @@ def _init(self, nodes: list[ir.Operation]) -> None:
27422752
self.process_grouped_nodes()
27432753

27442754
if (
2755+
# pyrefly: ignore[unbound-name]
27452756
config.graph_partition
2757+
# pyrefly: ignore[unbound-name]
27462758
and config.triton.cudagraphs
2759+
# pyrefly: ignore[unbound-name]
27472760
and config.triton.reorder_for_reducing_graph_partitions
27482761
):
27492762
self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
@@ -2755,6 +2768,7 @@ def _init(self, nodes: list[ir.Operation]) -> None:
27552768
self.insert_memory_check_nodes()
27562769

27572770
log_ir_post_fusion(self.nodes)
2771+
# pyrefly: ignore[unbound-name]
27582772
V.debug.graph_diagram(self.nodes)
27592773
self.debug_draw_graph()
27602774

0 commit comments

Comments
 (0)