
Commit 951aef9

[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers
## Changes

* Move the logic for skipping output nodes that are mutable buffers from the runtime to AOT.

## Context

An `fx.Graph` may return nodes that are mutable buffers:

```
class GraphModule(torch.nn.Module):
    def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"):
        sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
        ...

        # b_wrapped_module_kv_cache_*_cache are mutable buffers
        # getitem_2 and getitem_3 are derived from mutable buffers, hence they are
        # themselves mutable buffers
        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None
        getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None

        auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None
        getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None

        ...

        aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None
        aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None
        aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None
        aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None

        # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the
        # update_cache calls from being removed due to dead code elimination
        return (getitem_2, getitem_3, aten_view_copy_default_11, None)
```

In the graph signature of the `ExportedProgram`, these show up as `BUFFER_MUTATION` outputs:

```
Graph signature:
    # inputs
    p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight'
    p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight'
    p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight'
    p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight'
    b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True
    b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True
    x: USER_INPUT
    freqs_cos: USER_INPUT
    freqs_sin: USER_INPUT
    input_pos: USER_INPUT

    # outputs
    getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache'
    getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache'
    aten_view_copy_default_11: USER_OUTPUT
    : USER_OUTPUT
```

Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch between the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime.
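The `BUFFER_MUTATION` outputs discussed above can be enumerated directly from the graph signature using the public `torch.export` API. The sketch below is illustrative, not part of this commit, and the helper name `mutated_buffer_output_names` is made up:

```python
# Sketch: list the graph outputs that exist only to record buffer mutations.
# These are exactly the outputs that are absent from the delegate call schema.
from torch.export import ExportedProgram
from torch.export.graph_signature import OutputKind


def mutated_buffer_output_names(ep: ExportedProgram) -> list:
    return [
        spec.arg.name
        for spec in ep.graph_signature.output_specs
        if spec.kind == OutputKind.BUFFER_MUTATION
    ]
```

For the llama-style graph above, this would return `["getitem_2", "getitem_3"]`.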
## Motivation

Previously, this mismatch was addressed in the runtime by skipping the processing of non-tensor outputs. However, that solution does not account for the fact that in some models, parameters of the model may be returned as outputs. In that case, the parameter outputs would be skipped, but the ExecuTorch runtime would still expect to receive them as outputs.

To solve the problem properly, this diff changes the serialization logic to check whether an output node is a mutable buffer and, if so, skip marking it as an output. In the runtime, all output nodes are processed instead of only tensor outputs.

Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/)

ghstack-source-id: 292684341
Pull Request resolved: #11983
Parent commit: 910cc4e

File tree

5 files changed: +34, -7 lines


backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 9 additions & 7 deletions
```diff
@@ -359,15 +359,11 @@ class GraphBuilder {
       vkFn(*compute_graph_, args);
     }
 
-    // Parse the outputs, which will be mostly tensors. For some reason,
-    // mutable buffers are shown to be returned in the fx.Graph but do not get
-    // returned by the delegate; this may be an implementation detail of how the
-    // executorch emitter handles mutable buffers.
+    // Parse the outputs, which will be mostly tensors but may contain tensorref
+    // values as well if the source graph returns parameter nodes.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      if (compute_graph_->val_is_tensor(ref)) {
-        compute_graph_->set_output_tensor(ref);
-      }
+      compute_graph_->set_output_value(ref);
     }
 
     if (compute_graph_->graphconfig().enable_querypool) {
@@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
           compute_graph->outputs()[i].staging,
           args[o]->toTensor().mutable_data_ptr(),
           args[o]->toTensor().numel());
+    }
+    // TensorRef values represent constant tensors which will not have been
+    // modified by the graph execution. Therefore, if a constant tensor is
+    // returned as an output, no action is required.
+    else if (compute_graph->val_is_tref(oref)) {
+      continue;
     } else {
       VK_THROW(
           "Could not handle output with type ",
```

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor(
   return idx;
 }
 
+ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
+  if (values_.at(idx).isTensor()) {
+    return set_output_tensor(idx);
+  }
+  outputs_.push_back({idx, kDummyValueRef});
+  return idx;
+}
+
 vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
     const ValueRef idx) {
   if (values_.at(idx).isInt()) {
```

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -658,6 +658,8 @@ class ComputeGraph final {
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
 
+  ValueRef set_output_value(const ValueRef idx);
+
   template <typename Block>
   vkapi::BufferBindInfo create_params_buffer(const Block& data) {
     param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
```

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,
+    is_mutable_buffer_node,
     is_param_node,
     is_symint_node,
 )
@@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None:
                     "the output node is being serialized before its corresponding "
                     "internal node which is not allowed."
                 )
+            # Mutable buffers outputs are not included as an output to the
+            # delegate call. Skip marking them as an output.
+            if is_mutable_buffer_node(out_node, self.program):
+                continue
+
             self.output_ids.append(self.node_to_value_ids[out_node])
 
     def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
```
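Read together with the context lines, the serialization-side behavior after this change can be sketched as below. The enclosing loop shape and the error-handling condition are assumptions reconstructed from the diff, not verbatim source; only the mutable-buffer `continue` is the behavior this commit adds:

```python
# Reconstructed sketch of process_output_node (loop structure assumed).
def process_output_node(self, node: Node) -> None:
    for out_node in node.all_input_nodes:  # assumed iteration over output nodes
        if out_node not in self.node_to_value_ids:  # assumed check
            raise AssertionError(
                "the output node is being serialized before its corresponding "
                "internal node which is not allowed."
            )
        # BUFFER_MUTATION outputs are not part of the delegate call schema,
        # so they are never recorded as delegate outputs.
        if is_mutable_buffer_node(out_node, self.program):
            continue
        self.output_ids.append(self.node_to_value_ids[out_node])
```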

backends/vulkan/utils.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
     )
 
 
+def is_mutable_buffer_node(
+    node: torch.fx.Node, exported_program: ExportedProgram
+) -> bool:
+    if node.target not in exported_program.graph_signature.inputs_to_buffers:
+        return False
+    buf = exported_program.graph_signature.inputs_to_buffers[node.target]
+    return buf in exported_program.graph_signature.buffers_to_mutate.values()
+
+
 def is_symint_node(node: torch.fx.Node) -> bool:
     """
     Returns true if the given node produces a SymInt value
```
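As a quick sanity check of `is_mutable_buffer_node`: exporting a module that mutates a registered buffer should yield `True` for that buffer's lifted placeholder and `False` for user inputs. The `Counter` module below is a hypothetical toy, not part of this commit:

```python
# Hypothetical usage check for is_mutable_buffer_node.
import torch
from executorch.backends.vulkan.utils import is_mutable_buffer_node


class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count.add_(1)  # in-place mutation -> recorded as BUFFER_MUTATION
        return x + self.count


ep = torch.export.export(Counter(), (torch.ones(1),))
for node in ep.graph.nodes:
    if node.op == "placeholder":
        # Expect True for the lifted `count` buffer, False for the input `x`.
        print(node.name, is_mutable_buffer_node(node, ep))
```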
