Tag mutated buffer for AOTI cuda partitioner (#14783)
This avoids having to copy mutated buffers back to the graph outputs.
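The idea, roughly, is to give the placeholder nodes of mutated buffers the same delegation tag as the partition that consumes them, so `to_backend` keeps the mutation inside the delegate. The snippet below is a minimal sketch of that idea, not the PR's actual diff; `tag_mutated_buffers` is a hypothetical helper, and the use of `node.meta["delegation_tag"]` follows the usual ExecuTorch partitioner convention.

```python
from torch.export import ExportedProgram


def tag_mutated_buffers(ep: ExportedProgram, tag: str) -> None:
    """Illustrative sketch: tag mutated-buffer placeholders for the delegate.

    Assumes the ExecuTorch convention of marking nodes to be delegated via
    node.meta["delegation_tag"]; the real change lives in the AOTI CUDA
    partitioner.
    """
    # buffers_to_mutate maps mutated output names -> buffer FQNs.
    mutated_buffers = set(ep.graph_signature.buffers_to_mutate.values())
    for node in ep.graph_module.graph.nodes:
        if node.op != "placeholder":
            continue
        # inputs_to_buffers maps placeholder names -> buffer FQNs;
        # non-buffer placeholders return None and are skipped.
        buffer_fqn = ep.graph_signature.inputs_to_buffers.get(node.name)
        if buffer_fqn in mutated_buffers:
            node.meta["delegation_tag"] = tag
```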
Before this PR, I was getting this graph:
```
graph():
%b_key_cache_0 : [num_users=1] = placeholder[target=b_key_cache_0]
%b_value_cache_0 : [num_users=1] = placeholder[target=b_value_cache_0]
%b_key_cache_1 : [num_users=1] = placeholder[target=b_key_cache_1]
%b_value_cache_1 : [num_users=1] = placeholder[target=b_value_cache_1]
%b_key_cache_2 : [num_users=1] = placeholder[target=b_key_cache_2]
%b_value_cache_2 : [num_users=1] = placeholder[target=b_value_cache_2]
%b_key_cache_3 : [num_users=1] = placeholder[target=b_key_cache_3]
%b_value_cache_3 : [num_users=1] = placeholder[target=b_value_cache_3]
...
%b_key_cache_29 : [num_users=1] = placeholder[target=b_key_cache_29]
%b_value_cache_29 : [num_users=1] = placeholder[target=b_value_cache_29]
%inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
%executorch_call_delegate : [num_users=61] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position, %b_value_cache_0, %b_key_cache_0, %b_value_cache_1, %b_key_cache_1, %b_value_cache_2, %b_key_cache_2, %b_value_cache_3, %b_key_cache_3, %b_value_cache_4, %b_key_cache_4, %b_value_cache_5, %b_key_cache_5, %b_value_cache_6, %b_key_cache_6, %b_value_cache_7, %b_key_cache_7, %b_value_cache_8, %b_key_cache_8, %b_value_cache_9, %b_key_cache_9, %b_value_cache_10, %b_key_cache_10, %b_value_cache_11, %b_key_cache_11, %b_value_cache_12, %b_key_cache_12, %b_value_cache_13, %b_key_cache_13, %b_value_cache_14, %b_key_cache_14, %b_value_cache_15, %b_key_cache_15, %b_value_cache_16, %b_key_cache_16, %b_value_cache_17, %b_key_cache_17, %b_value_cache_18, %b_key_cache_18, %b_value_cache_19, %b_key_cache_19, %b_value_cache_20, %b_key_cache_20, %b_value_cache_21, %b_key_cache_21, %b_value_cache_22, %b_key_cache_22, %b_value_cache_23, %b_key_cache_23, %b_value_cache_24, %b_key_cache_24, %b_value_cache_25, %b_key_cache_25, %b_value_cache_26, %b_key_cache_26, %b_value_cache_27, %b_key_cache_27, %b_value_cache_28, %b_key_cache_28, %b_value_cache_29, %b_key_cache_29), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 1), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 2), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 4), kwargs = {})
...
%getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 60), kwargs = {})
return (getitem_1, getitem, getitem_3, getitem_2, getitem_5, getitem_4, getitem_7, getitem_6, getitem_9, getitem_8, getitem_11, getitem_10, getitem_13, getitem_12, getitem_15, getitem_14, getitem_17, getitem_16, getitem_19, getitem_18, getitem_21, getitem_20, getitem_23, getitem_22, getitem_25, getitem_24, getitem_27, getitem_26, getitem_29, getitem_28, getitem_31, getitem_30, getitem_33, getitem_32, getitem_35, getitem_34, getitem_37, getitem_36, getitem_39, getitem_38, getitem_41, getitem_40, getitem_43, getitem_42, getitem_45, getitem_44, getitem_47, getitem_46, getitem_49, getitem_48, getitem_51, getitem_50, getitem_53, getitem_52, getitem_55, getitem_54, getitem_57, getitem_56, getitem_59, getitem_58, getitem_60)

/home/larryliu/.conda/envs/executorch/lib/python3.11/site-packages/executorch/exir/emit/_emitter.py:1595: UserWarning: Mutation on a buffer in the model is detected. ExecuTorch assumes buffers that are mutated in the graph have a meaningless initial state, only the shape and dtype will be serialized, unless a pass which sets meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.
  warnings.warn(
```
This is unnecessary: we don't want the KV cache tensors returned as graph outputs.
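As an aside, the emitter warning above concerns the initial state of mutated buffers: by default only shape and dtype are serialized. If a mutated buffer's initial contents do matter, the warning points at setting `meta["et_init_buffer"]` (e.g. via `InitializedMutableBufferPass`). The hand-rolled loop below is only a sketch of that flag; `mark_buffers_initialized` is a hypothetical helper, not ExecuTorch API.

```python
def mark_buffers_initialized(ep, buffer_fqns):
    # Hypothetical helper: set the et_init_buffer flag the emitter warning
    # mentions, so the named buffers keep their serialized initial values
    # instead of only shape and dtype.
    sig = ep.graph_signature
    for node in ep.graph_module.graph.nodes:
        if node.op != "placeholder":
            continue
        if sig.inputs_to_buffers.get(node.name) in buffer_fqns:
            node.meta["et_init_buffer"] = True
```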
After applying this PR, I get this graph instead:
```
graph():
%inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
%executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position), kwargs = {})
%getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
return (getitem_60,)
```
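One way to verify the effect is to check the top-level graph signature: with the mutated KV cache buffers tagged into the delegate, it should no longer contain any `BUFFER_MUTATION` outputs. This is a hypothetical check, assuming `ep` is the post-partitioning `ExportedProgram`.

```python
from torch.export.graph_signature import OutputKind

# Hypothetical sanity check: no mutated buffer should be copied back out
# of the top-level graph as a BUFFER_MUTATION output anymore.
assert not any(
    spec.kind == OutputKind.BUFFER_MUTATION
    for spec in ep.graph_signature.output_specs
), "mutated buffers are still copied back as graph outputs"
```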