Commit 8b22352
[SWDEV-531526] [SWDEV-527340] Allocation of buffers ordered before compute (#2276)
Ensure fused nodes that allocate buffers come before kernels that
use those buffers.
In one example we observed:
- op8 creates buf10 which mutates buf8
- triton_poi_fused_index_put_lift_fresh_2 kernel tries to use buf8 and
buf9
- op6_op7_op16 (fused node) creates buf8 and buf9
But the standard topological sort did not ensure that the fused node
creating buf8 and buf9 came before the kernel using them.
After this PR, we identify that op8 performs a mutation on buf8, find the
node responsible for creating the buffer (op6_op7_op16), and add an
explicit dependency so that op8 depends on op6_op7_op16 and the graph is
ordered accordingly.
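The idea can be sketched in plain Python. This is a minimal illustration only, not the Inductor scheduler code: the `nodes` dictionary, its `creates`/`mutates` fields, and the `add_mutation_deps` helper are hypothetical names invented for this sketch, and the standard-library `graphlib.TopologicalSorter` stands in for the scheduler's own sort.

```python
from graphlib import TopologicalSorter

# Hypothetical, simplified node records mirroring the example above.
# These are NOT Inductor's scheduler classes.
nodes = {
    "op8": {"creates": {"buf10"}, "mutates": {"buf8"}},
    "op6_op7_op16": {"creates": {"buf8", "buf9"}, "mutates": set()},
}


def add_mutation_deps(nodes):
    """Return {node: predecessors}, adding an edge from the node that
    creates a buffer to every other node that mutates that buffer."""
    # Map each buffer to the node that creates it.
    creator = {buf: name for name, n in nodes.items() for buf in n["creates"]}
    deps = {name: set() for name in nodes}
    for name, n in nodes.items():
        for buf in n["mutates"]:
            owner = creator.get(buf)
            if owner is not None and owner != name:
                deps[name].add(owner)  # the mutating node must run after the creator
    return deps


deps = add_mutation_deps(nodes)
order = list(TopologicalSorter(deps).static_order())
print(order)
```

With the explicit edge in place, any valid topological order puts op6_op7_op16 before op8, regardless of the order in which the nodes were originally listed.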
Note: this issue is not seen in PT2.7; it is not clear why. We will hold
back on upstreaming this until we observe a similar issue on nightly.
Reproducer code (simplified from megatron)
https://gist.github.com/jataylo/10bedef08323441c588d2965ad963ae8
Execute with
> torchrun --nproc_per_node 1 repro.py
Before PR
```
[rank0]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 466, in __call__
[rank0]: return self.current_callable(inputs)
[rank0]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2128, in run
[rank0]: return model(new_inputs)
[rank0]: File "/tmp/torchinductor_root/gp/cgpe6weswyihhm442ugdhqxypbr7urxgk3adfr25onncik6tvthr.py", line 423, in call
[rank0]: triton_poi_fused_index_put_lift_fresh_2.run(buf9, buf8, 256, grid=grid(256), stream=stream0)
[rank0]: UnboundLocalError: local variable 'buf9' referenced before assignment
```
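For readers unfamiliar with this failure mode, the following toy function shows why a misordered wrapper raises `UnboundLocalError` rather than `NameError`. It is a hypothetical stand-in, not the generated Inductor code: `call` and the arithmetic are placeholders for a kernel launch that was emitted ahead of the allocation it needs.

```python
def call():
    # Stand-in for a kernel launch emitted before its input is allocated.
    # buf9 is assigned later in this function, so Python treats it as a
    # local variable and this read fails at runtime.
    out = buf9 + 1
    buf9 = 0  # stand-in for the allocation that should have come first
    return out


try:
    call()
    raised = False
except UnboundLocalError:
    raised = True

print(raised)
```

Reordering the two statements inside `call` removes the error, which is the effect the added dependency has on the generated code.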
Note the simpler repro fails on both CUDA and ROCm and shows a logic issue
across PT2.6; more details in the gist.
1 parent 9d15d89 commit 8b22352
1 file changed: +23 -1 lines changed (hunk at original lines 2247-2267, new lines 2247-2289)