Skip to content

Commit 8b22352

Browse files
authored
[SWDEV-531526] [SWDEV-527340] Allocation of buffers ordered before compute (#2276)
Ensure fused nodes that allocate buffers come before kernels that usethose buffers In one example we observed: - op8 creates buf10 which mutates buf8 - triton_poi_fused_index_put_lift_fresh_2 kernel tries to use buf8 and buf9 - op6_op7_op16 (fused node) creates buf8 and buf9 But the standard topological sort didn't ensure that the fused node creating buf8 and buf9 came before the kernel using them. After this PR we will identify op8 performs a mutation on buf8, find the node that is responsible for creating the buffer (op6_op7_op16) and add an explicit dependency so now op8 depends on op6_op7_op16 and orders graph accordingly. Note this issue is not seen in PT2.7, not clear as to why. We will hold back on upstreaming this until we observe a similar issue on nightly. Reproducer code (simplified from megatron) https://gist.github.com/jataylo/10bedef08323441c588d2965ad963ae8 Execute with > torchrun --nproc_per_node 1 repro.py Before PR ``` [rank0]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 466, in __call__ [rank0]: return self.current_callable(inputs) [rank0]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2128, in run [rank0]: return model(new_inputs) [rank0]: File "/tmp/torchinductor_root/gp/cgpe6weswyihhm442ugdhqxypbr7urxgk3adfr25onncik6tvthr.py", line 423, in call [rank0]: triton_poi_fused_index_put_lift_fresh_2.run(buf9, buf8, 256, grid=grid(256), stream=stream0) [rank0]: UnboundLocalError: local variable 'buf9' referenced before assignment ``` Note the simpler repro fails for both CUDA/ROCm and shows a logic issue across PT2.6, more details in gist.
1 parent 9d15d89 commit 8b22352

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

torch/_inductor/scheduler.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2247,21 +2247,43 @@ def topological_sort_schedule(
22472247
name_to_node: Dict[str, BaseSchedulerNode] = dict()
22482248
result: List[BaseSchedulerNode] = []
22492249

2250+
def has_mutations(node: BaseSchedulerNode) -> bool:
2251+
return any(buf.get_mutations() for buf in node.get_outputs())
2252+
22502253
def visit(n: BaseSchedulerNode) -> None:
22512254
if n not in seen:
22522255
seen.add(n)
2256+
2257+
# Visit regular dependencies
22532258
for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
22542259
# We only care about doing toposort within `nodes`
22552260
if dep.name not in name_to_node:
22562261
continue
22572262
visit(name_to_node[dep.name])
2263+
2264+
# Visit mutation dependencies
2265+
for buf in n.get_outputs():
2266+
for mutation in buf.get_mutations():
2267+
if mutation in name_to_node and name_to_node[mutation] != n:
2268+
visit(name_to_node[mutation])
2269+
22582270
result.append(n)
22592271

2272+
# Build name to node mapping
22602273
for node in nodes:
22612274
for name in node.get_buffer_names():
22622275
name_to_node[name] = node
2276+
2277+
# Visit non-mutation nodes first
2278+
for node in nodes:
2279+
if not has_mutations(node):
2280+
visit(node)
2281+
2282+
# Then visit mutation nodes
22632283
for node in nodes:
2264-
visit(node)
2284+
if has_mutations(node):
2285+
visit(node)
2286+
22652287
return result
22662288

22672289
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:

0 commit comments

Comments
 (0)