From adf0cfa5195924f476c580a089c833fc444adf3b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 25 Jun 2025 13:08:47 -0700 Subject: [PATCH 1/3] [ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers ## Changes * Move the logic skipping output nodes that are mutable buffers from runtime to AOT ## Context A `fx.Graph` may return nodes that are mutable buffers: ``` class GraphModule(torch.nn.Module): def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"): sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1) ... # b_wrapped_module_kv_cache_*_cache are mutable buffers # getitem_2 and getitem_3 are derived from mutable buffers, hence they are # themselves mutable buffers auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None ... aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the # update_cache calls from being removed due to dead code elimination return (getitem_2, getitem_3, aten_view_copy_default_11, None) ``` In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs ``` Graph signature: # inputs p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight' p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight' p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight' p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight' b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True x: USER_INPUT freqs_cos: USER_INPUT freqs_sin: USER_INPUT input_pos: USER_INPUT # outputs getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache' getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache' aten_view_copy_default_11: USER_OUTPUT : USER_OUTPUT ``` Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch betwen the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime. ## Motivation Previously, this mismatch was addressed in the runtime, by skipping the processing of non-tensor outputs. However, this solution does not account for the fact that in some models, paramters of the model may be returned as outputs. In this case, those parameter outputs would be skipped but the ExecuTorch runtime would still expect to receive them as outputs. To solve the problem properly, this diff changes the serialization logic to check if an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only processing tensor outputs. Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 16 +++++++++------- backends/vulkan/runtime/graph/ComputeGraph.cpp | 8 ++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 2 ++ .../vulkan/serialization/vulkan_graph_builder.py | 6 ++++++ backends/vulkan/utils.py | 9 +++++++++ 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 218bd31d1b4..7077a9df59c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -359,15 +359,11 @@ class GraphBuilder { vkFn(*compute_graph_, args); } - // Parse the outputs, which will be mostly tensors. For some reason, - // mutable buffers are shown to be returned in the fx.Graph but do not get - // returned by the delegate; this may be an implementation detail of how the - // executorch emitter handles mutable buffers. + // Parse the outputs, which will be mostly tensors but may contain tensorref + // values as well if the source graph returns parameter nodes. for (const uint32_t fb_id : *flatbuffer_->output_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); - if (compute_graph_->val_is_tensor(ref)) { - compute_graph_->set_output_tensor(ref); - } + compute_graph_->set_output_value(ref); } if (compute_graph_->graphconfig().enable_querypool) { @@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), args[o]->toTensor().numel()); + } + // TensorRef values represent constant tensors which will not have been + // modified by the graph execution. Therefore, if a constant tensor is + // returned as an output, no action is required. + else if (compute_graph->val_is_tref(oref)) { + continue; } else { VK_THROW( "Could not handle output with type ", diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index b63f89e299d..cb14a41e98a 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor( return idx; } +ValueRef ComputeGraph::set_output_value(const ValueRef idx) { + if (values_.at(idx).isTensor()) { + return set_output_tensor(idx); + } + outputs_.push_back({idx, kDummyValueRef}); + return idx; +} + vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( const ValueRef idx) { if (values_.at(idx).isInt()) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index eac632e6d35..78135a434e5 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -658,6 +658,8 @@ class ComputeGraph final { ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); + ValueRef set_output_value(const ValueRef idx); + template vkapi::BufferBindInfo create_params_buffer(const Block& data) { param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data)); diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index d21d33b75da..5bae0475c28 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -20,6 +20,7 @@ from executorch.backends.vulkan.utils import ( is_constant, is_get_attr_node, + is_mutable_buffer_node, is_param_node, is_symint_node, ) @@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None: "the output node is being serialized before its corresponding " "internal node which is not allowed." ) + # Mutable buffers outputs are not included as an output to the + # delegate call. Skip marking them as an output. + if is_mutable_buffer_node(out_node, self.program): + continue + self.output_ids.append(self.node_to_value_ids[out_node]) def process_node(self, node: Node, call_node_debug_hdl: int) -> None: diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 5d57ce1e7be..d71c0a35776 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: ) +def is_mutable_buffer_node( + node: torch.fx.Node, exported_program: ExportedProgram +) -> bool: + if node.target not in exported_program.graph_signature.inputs_to_buffers: + return False + buf = exported_program.graph_signature.inputs_to_buffers[node.target] + return buf in exported_program.graph_signature.buffers_to_mutate.values() + + def is_symint_node(node: torch.fx.Node) -> bool: """ Returns true if the given node produces a SymInt value From 8ae60e429a4f5322cc0e6166b2ae805af20dffd8 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 25 Jun 2025 13:59:56 -0700 Subject: [PATCH 2/3] Update on "[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers" ## Changes * Move the logic skipping output nodes that are mutable buffers from runtime to AOT ## Context A `fx.Graph` may return nodes that are mutable buffers: ``` class GraphModule(torch.nn.Module): def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"): sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1) ... # b_wrapped_module_kv_cache_*_cache are mutable buffers # getitem_2 and getitem_3 are derived from mutable buffers, hence they are # themselves mutable buffers auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None ... aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the # update_cache calls from being removed due to dead code elimination return (getitem_2, getitem_3, aten_view_copy_default_11, None) ``` In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs ``` Graph signature: # inputs p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight' p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight' p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight' p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight' b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True x: USER_INPUT freqs_cos: USER_INPUT freqs_sin: USER_INPUT input_pos: USER_INPUT # outputs getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache' getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache' aten_view_copy_default_11: USER_OUTPUT : USER_OUTPUT ``` Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch betwen the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime. ## Motivation Previously, this mismatch was addressed in the runtime, by skipping the processing of non-tensor outputs. However, this solution does not account for the fact that in some models, paramters of the model may be returned as outputs. In this case, those parameter outputs would be skipped but the ExecuTorch runtime would still expect to receive them as outputs. To solve the problem properly, this diff changes the serialization logic to check if an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only processing tensor outputs. Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/) [ghstack-poisoned] From 9882756a187627c8595875a540f07be42bc4b16e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 25 Jun 2025 14:54:31 -0700 Subject: [PATCH 3/3] Update on "[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers" ## Changes * Move the logic skipping output nodes that are mutable buffers from runtime to AOT ## Context A `fx.Graph` may return nodes that are mutable buffers: ``` class GraphModule(torch.nn.Module): def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"): sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1) ... # b_wrapped_module_kv_cache_*_cache are mutable buffers # getitem_2 and getitem_3 are derived from mutable buffers, hence they are # themselves mutable buffers auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1); getitem_1 = b_wrapped_module_kv_cache_k_cache = None getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1]; auto_functionalized = None auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1); aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1]; auto_functionalized_1 = None ... aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]); p_wrapped_module_wo_weight = None aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]); aten_view_copy_default_9 = None aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3); aten_view_copy_default_10 = aten_permute_copy_default_3 = None aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]); aten_mm_default_3 = sym_size = None # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the # update_cache calls from being removed due to dead code elimination return (getitem_2, getitem_3, aten_view_copy_default_11, None) ``` In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs ``` Graph signature: # inputs p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight' p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight' p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight' p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight' b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True x: USER_INPUT freqs_cos: USER_INPUT freqs_sin: USER_INPUT input_pos: USER_INPUT # outputs getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache' getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache' aten_view_copy_default_11: USER_OUTPUT : USER_OUTPUT ``` Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch betwen the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime. ## Motivation Previously, this mismatch was addressed in the runtime, by skipping the processing of non-tensor outputs. However, this solution does not account for the fact that in some models, paramters of the model may be returned as outputs. In this case, those parameter outputs would be skipped but the ExecuTorch runtime would still expect to receive them as outputs. To solve the problem properly, this diff changes the serialization logic to check if an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only processing tensor outputs. Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/) [ghstack-poisoned]