Commit 951aef9
committed
[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers
## Changes
* Move the logic skipping output nodes that are mutable buffers from runtime to AOT
## Context
A `fx.Graph` may return nodes that are mutable buffers:
```
class GraphModule(torch.nn.Module):
def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"):
sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
...
# b_wrapped_module_kv_cache_*_cache are mutable buffers
# getitem_2 and getitem_3 are derived from mutable buffers, hence they are
# themselves mutable buffers
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None
getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None
getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None
...
aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None
aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None
aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None
aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None
# getitem_2 and getitem_3 are returned as outputs, presumably to prevent the
# update_cache calls from being removed due to dead code elimination
return (getitem_2, getitem_3, aten_view_copy_default_11, None)
```
In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs
```
Graph signature:
# inputs
p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight'
p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight'
p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight'
p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight'
b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True
b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True
x: USER_INPUT
freqs_cos: USER_INPUT
freqs_sin: USER_INPUT
input_pos: USER_INPUT
# outputs
getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache'
getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache'
aten_view_copy_default_11: USER_OUTPUT
: USER_OUTPUT
```
Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch betwen the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime.
## Motivation
Previously, this mismatch was addressed in the runtime, by skipping the processing of non-tensor outputs. However, this solution does not account for the fact that in some models, paramters of the model may be returned as outputs. In this case, those parameter outputs would be skipped but the ExecuTorch runtime would still expect to receive them as outputs.
To solve the problem properly, this diff changes the serialization logic to check if an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only processing tensor outputs.
Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/)
ghstack-source-id: 292684341
Pull Request resolved: #119831 parent 910cc4e commit 951aef9
File tree
5 files changed
+34
-7
lines changed- backends/vulkan
- runtime
- graph
- serialization
5 files changed
+34
-7
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
359 | 359 | | |
360 | 360 | | |
361 | 361 | | |
362 | | - | |
363 | | - | |
364 | | - | |
365 | | - | |
| 362 | + | |
| 363 | + | |
366 | 364 | | |
367 | 365 | | |
368 | | - | |
369 | | - | |
370 | | - | |
| 366 | + | |
371 | 367 | | |
372 | 368 | | |
373 | 369 | | |
| |||
609 | 605 | | |
610 | 606 | | |
611 | 607 | | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
612 | 614 | | |
613 | 615 | | |
614 | 616 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
519 | 519 | | |
520 | 520 | | |
521 | 521 | | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
522 | 530 | | |
523 | 531 | | |
524 | 532 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
658 | 658 | | |
659 | 659 | | |
660 | 660 | | |
| 661 | + | |
| 662 | + | |
661 | 663 | | |
662 | 664 | | |
663 | 665 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
| 23 | + | |
23 | 24 | | |
24 | 25 | | |
25 | 26 | | |
| |||
382 | 383 | | |
383 | 384 | | |
384 | 385 | | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
385 | 391 | | |
386 | 392 | | |
387 | 393 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
84 | 84 | | |
85 | 85 | | |
86 | 86 | | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
87 | 96 | | |
88 | 97 | | |
89 | 98 | | |
| |||
0 commit comments