
Commit 7116e0a

Tag mutated buffer for AOTI cuda partitioner (#14783)
This avoids having to copy mutated buffers back to the outputs. Before this PR I'm getting this graph:

```
graph():
    %b_key_cache_0 : [num_users=1] = placeholder[target=b_key_cache_0]
    %b_value_cache_0 : [num_users=1] = placeholder[target=b_value_cache_0]
    %b_key_cache_1 : [num_users=1] = placeholder[target=b_key_cache_1]
    %b_value_cache_1 : [num_users=1] = placeholder[target=b_value_cache_1]
    %b_key_cache_2 : [num_users=1] = placeholder[target=b_key_cache_2]
    %b_value_cache_2 : [num_users=1] = placeholder[target=b_value_cache_2]
    %b_key_cache_3 : [num_users=1] = placeholder[target=b_key_cache_3]
    %b_value_cache_3 : [num_users=1] = placeholder[target=b_value_cache_3]
    ...
    %b_key_cache_29 : [num_users=1] = placeholder[target=b_key_cache_29]
    %b_value_cache_29 : [num_users=1] = placeholder[target=b_value_cache_29]
    %inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
    %cache_position : [num_users=1] = placeholder[target=cache_position]
    %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
    %executorch_call_delegate : [num_users=61] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position, %b_value_cache_0, %b_key_cache_0, %b_value_cache_1, %b_key_cache_1, %b_value_cache_2, %b_key_cache_2, %b_value_cache_3, %b_key_cache_3, %b_value_cache_4, %b_key_cache_4, %b_value_cache_5, %b_key_cache_5, %b_value_cache_6, %b_key_cache_6, %b_value_cache_7, %b_key_cache_7, %b_value_cache_8, %b_key_cache_8, %b_value_cache_9, %b_key_cache_9, %b_value_cache_10, %b_key_cache_10, %b_value_cache_11, %b_key_cache_11, %b_value_cache_12, %b_key_cache_12, %b_value_cache_13, %b_key_cache_13, %b_value_cache_14, %b_key_cache_14, %b_value_cache_15, %b_key_cache_15, %b_value_cache_16, %b_key_cache_16, %b_value_cache_17, %b_key_cache_17, %b_value_cache_18, %b_key_cache_18, %b_value_cache_19, %b_key_cache_19, %b_value_cache_20, %b_key_cache_20, %b_value_cache_21, %b_key_cache_21, %b_value_cache_22, %b_key_cache_22, %b_value_cache_23, %b_key_cache_23, %b_value_cache_24, %b_key_cache_24, %b_value_cache_25, %b_key_cache_25, %b_value_cache_26, %b_key_cache_26, %b_value_cache_27, %b_key_cache_27, %b_value_cache_28, %b_key_cache_28, %b_value_cache_29, %b_key_cache_29), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 4), kwargs = {})
    ...
    %getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 60), kwargs = {})
    return (getitem_1, getitem, getitem_3, getitem_2, getitem_5, getitem_4, getitem_7, getitem_6, getitem_9, getitem_8, getitem_11, getitem_10, getitem_13, getitem_12, getitem_15, getitem_14, getitem_17, getitem_16, getitem_19, getitem_18, getitem_21, getitem_20, getitem_23, getitem_22, getitem_25, getitem_24, getitem_27, getitem_26, getitem_29, getitem_28, getitem_31, getitem_30, getitem_33, getitem_32, getitem_35, getitem_34, getitem_37, getitem_36, getitem_39, getitem_38, getitem_41, getitem_40, getitem_43, getitem_42, getitem_45, getitem_44, getitem_47, getitem_46, getitem_49, getitem_48, getitem_51, getitem_50, getitem_53, getitem_52, getitem_55, getitem_54, getitem_57, getitem_56, getitem_59, getitem_58, getitem_60)
```

along with this warning:

```
/home/larryliu/.conda/envs/executorch/lib/python3.11/site-packages/executorch/exir/emit/_emitter.py:1595: UserWarning: Mutation on a buffer in the model is detected. ExecuTorch assumes buffers that are mutated in the graph have a meaningless initial state, only the shape and dtype will be serialized, unless a pass which sets meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.
  warnings.warn(
```

This is unnecessary because we don't want the KV cache as an output. After applying this PR I'm getting this graph instead:

```
graph():
    %inputs_embeds : [num_users=1] = placeholder[target=inputs_embeds]
    %cache_position : [num_users=1] = placeholder[target=cache_position]
    %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
    %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %inputs_embeds, %cache_position), kwargs = {})
    %getitem_60 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
    return (getitem_60,)
```
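The fix works by tagging the mutated-buffer placeholders with the partition's delegation tag, so the partitioner keeps them inside the delegate instead of threading them through as extra graph inputs and outputs. As a rough illustration of the idea (a toy sketch, not the real `tag_mutated_buffer` from `executorch.exir.backend.utils`; `Node` and `tag_mutated_buffers` here are hypothetical stand-ins for the `torch.fx` equivalents):

```python
# Toy sketch of tagging mutated-buffer placeholders for delegation.
# NOT the real executorch implementation: Node stands in for torch.fx.Node,
# and the set of mutated buffer targets is assumed given by the ExportedProgram.
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str                      # e.g. "placeholder", "call_function"
    target: str = ""
    meta: dict = field(default_factory=dict)


def tag_mutated_buffers(nodes, mutated_buffer_targets, tag):
    """Tag placeholder nodes backed by mutated buffers with the delegation
    tag, so the partitioner absorbs them into the delegate rather than
    copying them back out as graph outputs."""
    for node in nodes:
        if node.op == "placeholder" and node.target in mutated_buffer_targets:
            node.meta["delegation_tag"] = tag


nodes = [
    Node("b_key_cache_0", "placeholder", "b_key_cache_0"),
    Node("inputs_embeds", "placeholder", "inputs_embeds"),
]
tag_mutated_buffers(nodes, {"b_key_cache_0"}, "tag0")
print(nodes[0].meta.get("delegation_tag"))  # tag0
print(nodes[1].meta.get("delegation_tag"))  # None
```

Once the KV-cache placeholders carry the delegation tag, only the genuine user-facing inputs (`inputs_embeds`, `cache_position`) and output remain at the top level, which matches the second graph above.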
1 parent 4d681cb commit 7116e0a

File tree

1 file changed

+2
-1
lines changed


backends/cuda/cuda_partitioner.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,7 +15,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from torch.export.exported_program import ExportedProgram
 
 
@@ -54,6 +54,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         partition_tags[tag] = self.delegation_spec
 
         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)
 
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
```
