Tag mutated buffer for AOTI cuda partitioner (pytorch#14783)
This should avoid having to copy mutated buffers back to the graph outputs.
Before this PR I'm getting this graph, where every mutated KV-cache buffer is threaded through the delegate and returned as an output:
```
graph():
%b_key_cache_0 : [num_users=1] = placeholder[target=b_key_cache_0]
%b_value_cache_0 : [num_users=1] = placeholder[target=b_value_cache_0]
%b_key_cache_1 : [num_users=1] = placeholder[target=b_key_cache_1]
%b_value_cache_1 : [num_users=1] = placeholder[target=b_value_cache_1]
%b_key_cache_2 : [num_users=1] = placeholder[target=b_key_cache_2]
%b_value_cache_2 : [num_users=1] = placeholder[target=b_value_cache_2]
%b_key_cache_3 : [num_users=1] = placeholder[target=b_key_cache_3]
%b_value_cache_3 : [num_users=1] = placeholder[target=b_value_cache_3]
...
%b_key_cache_29 : [num_users=1] = placeholder[target=b_key_cache_29]
%b_value_cache_29 : [num_users=1] = placeholder[target=b_value_cache_29]
%inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
%executorch_call_delegate : [num_users=61] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position, %b_value_cache_0, %b_key_cache_0, %b_value_cache_1, %b_key_cache_1, %b_value_cache_2, %b_key_cache_2, %b_value_cache_3, %b_key_cache_3, %b_value_cache_4, %b_key_cache_4, %b_value_cache_5, %b_key_cache_5, %b_value_cache_6, %b_key_cache_6, %b_value_cache_7, %b_key_cache_7, %b_value_cache_8, %b_key_cache_8, %b_value_cache_9, %b_key_cache_9, %b_value_cache_10, %b_key_cache_10, %b_value_cache_11, %b_key_cache_11, %b_value_cache_12, %b_key_cache_12, %b_value_cache_13, %b_key_cache_13, %b_value_cache_14, %b_key_cache_14, %b_value_cache_15, %b_key_cache_15, %b_value_cache_16, %b_key_cache_16, %b_value_cache_17, %b_key_cache_17, %b_value_cache_18, %b_key_cache_18, %b_value_cache_19, %b_key_cache_19, %b_value_cache_20, %b_key_cache_20, %b_value_cache_21, %b_key_cache_21, %b_value_cache_22, %b_key_cache_22, %b_value_cache_23, %b_key_cache_23, %b_value_cache_24, %b_key_cache_24, %b_value_cache_25, %b_key_cache_25, %b_value_cache_26, %b_key_cache_26, %b_value_cache_27, %b_key_cache_27, %b_value_cache_28, %b_key_cache_28, %b_value_cache_29, %b_key_cache_29), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 1), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 2), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 4), kwargs = {})
...
%getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 60), kwargs = {})
return (getitem_1, getitem, getitem_3, getitem_2, getitem_5, getitem_4, getitem_7, getitem_6, getitem_9, getitem_8, getitem_11, getitem_10, getitem_13, getitem_12, getitem_15, getitem_14, getitem_17, getitem_16, getitem_19, getitem_18, getitem_21, getitem_20, getitem_23, getitem_22, getitem_25, getitem_24, getitem_27, getitem_26, getitem_29, getitem_28, getitem_31, getitem_30, getitem_33, getitem_32, getitem_35, getitem_34, getitem_37, getitem_36, getitem_39, getitem_38, getitem_41, getitem_40, getitem_43, getitem_42, getitem_45, getitem_44, getitem_47, getitem_46, getitem_49, getitem_48, getitem_51, getitem_50, getitem_53, getitem_52, getitem_55, getitem_54, getitem_57, getitem_56, getitem_59, getitem_58, getitem_60)

/home/larryliu/.conda/envs/executorch/lib/python3.11/site-packages/executorch/exir/emit/_emitter.py:1595: UserWarning: Mutation on a buffer in the model is detected. ExecuTorch assumes buffers that are mutated in the graph have a meaningless initial state, only the shape and dtype will be serialized, unless a pass which sets meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.
  warnings.warn(
```
This is unnecessary because we don't want the KV cache returned as an output.
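As an aside, the `UserWarning` in the output above is about initial buffer state: by default ExecuTorch serializes only the shape and dtype of a mutated buffer, and a pass has to set `meta["et_init_buffer"]` to `True` (e.g. the in-tree `InitializedMutableBufferPass` it mentions) if the initial contents matter. A minimal sketch of that mechanism; the function name and the name-based matching are hypothetical, only the meta key comes from the warning text:

```python
# Hypothetical sketch: mark chosen mutated buffers so their initial state is
# serialized. Only meta["et_init_buffer"] comes from the warning above; the
# function name and name-based matching are assumptions for illustration.
from typing import Set

from torch.export import ExportedProgram


def mark_initialized_buffers(ep: ExportedProgram, buffer_names: Set[str]) -> None:
    for node in ep.graph_module.graph.nodes:
        # Buffers show up as placeholder nodes in the exported graph.
        if node.op == "placeholder" and node.target in buffer_names:
            node.meta["et_init_buffer"] = True  # serialize initial contents too
```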
After applying this PR I'm getting this graph instead:
```
graph():
%inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
%executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position), kwargs = {})
%getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
return (getitem_60,)
```
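The fix itself is small (+2/-1 in `cuda_partitioner.py`; see the diff summary below): the partitioner tags the mutated-buffer placeholders with the same delegation tag as the compute nodes, so the delegate owns the mutation and the top-level graph no longer threads the buffers through its inputs and outputs. Here is a minimal sketch of that pattern, assuming the `tag_mutated_buffer` helper from `executorch.exir.backend.utils`; the class name and backend id are illustrative, not the exact lines added by this PR:

```python
# Minimal sketch of a partitioner that claims every call_function node and
# additionally tags mutated buffers so they stay inside the delegate.
# Backend id and class name are illustrative; only the tagging pattern matters.
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_mutated_buffer
from torch.export import ExportedProgram


class SketchCudaPartitioner(Partitioner):
    def __init__(self) -> None:
        self.delegation_spec = DelegationSpec("CudaBackend", [])

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        tag = "tag0"
        for node in exported_program.graph_module.graph.nodes:
            if node.op == "call_function":
                node.meta["delegation_tag"] = tag

        # Propagate the delegation tag onto mutated-buffer placeholders so the
        # buffers are consumed (and mutated) inside the delegate rather than
        # being copied back out through the graph's outputs.
        tag_mutated_buffer(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags={tag: self.delegation_spec},
        )
```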
1 file changed: backends/cuda/cuda_partitioner.py (+2, -1)
Diff: one line replaced at line 18, one line added at line 57.