Commit 3b41cb5
[release/2.8] [Bugfix][Inductor] Fix dependency list merged incorrectly for a custo… (#2419)
…m op with multiple mutated inputs and None return type. (pytorch#157133)
This is an attempt to fix a memory allocation issue when using
`torch.compile` with a custom layernorm kernel in vllm:
```C++
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```
We observed abnormal extra memory allocations with this op enabled using
`torch.compile`: <img width="738"
alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}"
src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5"
/> and without this op:
<img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}"
src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df"
/>
After investigation, we found that this is because the compiler
considers the two buffers for the two mutated inputs `Tensor input` and
`Tensor residual` should share a same dependency list, which makes it
can not reuse the buffer of `Tensor input`.
```
buf1.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf16.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
```
```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes =
[ StarDep(name='buf17', mode=None),
StarDep(name='buf18', mode=None),
StarDep(name='buf19', mode=None)]
op13.unmet_dependencies =
[ StarDep(name='buf13', mode=None),
StarDep(name='buf16', mode=None),
WeakDep(name='buf11', mutating_buf='buf18'),
WeakDep(name='buf12', mutating_buf='buf18'),
WeakDep(name='buf13', mutating_buf='buf18'),
WeakDep(name='buf2', mutating_buf='buf18'),
WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
buf17: FallbackKernel
buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf17.aliases = ['buf16', 'buf1']
buf17.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf18: MutationOutput
buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf18.mutations = ['buf16']
buf18.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
]
buf19: MutationOutput
buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf19.mutations = ['buf1']
buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```
Here we can see `buf16` shares the same dependency list with `buf1`
because `buf16` and `buf1` are in the aliases list of `buf17`. This is
incorrect since those two are two separate tensors. And this makes the
compiler could not reuse `buf16` for subsequent ops.
Pull Request resolved: pytorch#157133
Approved by: https://github.com/jansel
(cherry picked from commit 02724b5)
Fixes #ISSUE_NUMBER
Co-authored-by: charlifu <[email protected]>1 parent 3b7f377 commit 3b41cb5
File tree
5 files changed
+72
-4
lines changed- test
- dynamo
- inductor
- torch
- _inductor
- _logging
5 files changed
+72
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
959 | 959 | | |
960 | 960 | | |
961 | 961 | | |
| 962 | + | |
962 | 963 | | |
963 | 964 | | |
964 | 965 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
445 | 445 | | |
446 | 446 | | |
447 | 447 | | |
448 | | - | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
449 | 456 | | |
450 | 457 | | |
451 | | - | |
452 | | - | |
453 | | - | |
| 458 | + | |
454 | 459 | | |
455 | 460 | | |
456 | 461 | | |
| |||
1733 | 1738 | | |
1734 | 1739 | | |
1735 | 1740 | | |
| 1741 | + | |
| 1742 | + | |
| 1743 | + | |
| 1744 | + | |
| 1745 | + | |
| 1746 | + | |
| 1747 | + | |
| 1748 | + | |
| 1749 | + | |
| 1750 | + | |
| 1751 | + | |
| 1752 | + | |
| 1753 | + | |
| 1754 | + | |
| 1755 | + | |
| 1756 | + | |
| 1757 | + | |
| 1758 | + | |
| 1759 | + | |
| 1760 | + | |
| 1761 | + | |
| 1762 | + | |
| 1763 | + | |
| 1764 | + | |
| 1765 | + | |
| 1766 | + | |
| 1767 | + | |
| 1768 | + | |
| 1769 | + | |
| 1770 | + | |
| 1771 | + | |
| 1772 | + | |
| 1773 | + | |
| 1774 | + | |
| 1775 | + | |
1736 | 1776 | | |
1737 | 1777 | | |
1738 | 1778 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
74 | 74 | | |
75 | 75 | | |
76 | 76 | | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
77 | 80 | | |
78 | 81 | | |
79 | 82 | | |
| |||
2278 | 2281 | | |
2279 | 2282 | | |
2280 | 2283 | | |
| 2284 | + | |
| 2285 | + | |
| 2286 | + | |
| 2287 | + | |
| 2288 | + | |
| 2289 | + | |
| 2290 | + | |
| 2291 | + | |
| 2292 | + | |
2281 | 2293 | | |
2282 | 2294 | | |
2283 | 2295 | | |
| |||
2445 | 2457 | | |
2446 | 2458 | | |
2447 | 2459 | | |
| 2460 | + | |
| 2461 | + | |
| 2462 | + | |
| 2463 | + | |
| 2464 | + | |
| 2465 | + | |
| 2466 | + | |
| 2467 | + | |
| 2468 | + | |
| 2469 | + | |
| 2470 | + | |
| 2471 | + | |
2448 | 2472 | | |
2449 | 2473 | | |
2450 | 2474 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
252 | 252 | | |
253 | 253 | | |
254 | 254 | | |
| 255 | + | |
255 | 256 | | |
256 | 257 | | |
257 | 258 | | |
| |||
565 | 566 | | |
566 | 567 | | |
567 | 568 | | |
| 569 | + | |
568 | 570 | | |
569 | 571 | | |
570 | 572 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
183 | 183 | | |
184 | 184 | | |
185 | 185 | | |
| 186 | + | |
186 | 187 | | |
187 | 188 | | |
188 | 189 | | |
| |||
0 commit comments