diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 218bd31d1b4..7077a9df59c 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -359,15 +359,11 @@ class GraphBuilder {
       vkFn(*compute_graph_, args);
     }
 
-    // Parse the outputs, which will be mostly tensors. For some reason,
-    // mutable buffers are shown to be returned in the fx.Graph but do not get
-    // returned by the delegate; this may be an implementation detail of how the
-    // executorch emitter handles mutable buffers.
+    // Parse the outputs, which will be mostly tensors but may contain tensorref
+    // values as well if the source graph returns parameter nodes.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      if (compute_graph_->val_is_tensor(ref)) {
-        compute_graph_->set_output_tensor(ref);
-      }
+      compute_graph_->set_output_value(ref);
     }
 
     if (compute_graph_->graphconfig().enable_querypool) {
@@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
             compute_graph->outputs()[i].staging,
             args[o]->toTensor().mutable_data_ptr(),
             args[o]->toTensor().numel());
+      }
+      // TensorRef values represent constant tensors which will not have been
+      // modified by the graph execution. Therefore, if a constant tensor is
+      // returned as an output, no action is required.
+      else if (compute_graph->val_is_tref(oref)) {
+        continue;
       } else {
         VK_THROW(
             "Could not handle output with type ",
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index b63f89e299d..cb14a41e98a 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor(
   return idx;
 }
 
+ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
+  if (values_.at(idx).isTensor()) {
+    return set_output_tensor(idx);
+  }
+  outputs_.push_back({idx, kDummyValueRef});
+  return idx;
+}
+
 vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
     const ValueRef idx) {
   if (values_.at(idx).isInt()) {
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index eac632e6d35..78135a434e5 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -658,6 +658,8 @@ class ComputeGraph final {
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
 
+  ValueRef set_output_value(const ValueRef idx);
+
   template <typename Block>
   vkapi::BufferBindInfo create_params_buffer(const Block& data) {
     param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index d21d33b75da..5bae0475c28 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -20,6 +20,7 @@
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,
+    is_mutable_buffer_node,
     is_param_node,
     is_symint_node,
 )
@@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None:
                 "the output node is being serialized before its corresponding "
                 "internal node which is not allowed."
             )
+            # Mutable buffer outputs are not included as an output to the
+            # delegate call. Skip marking them as an output.
+            if is_mutable_buffer_node(out_node, self.program):
+                continue
+
             self.output_ids.append(self.node_to_value_ids[out_node])
 
     def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 5d57ce1e7be..d71c0a35776 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
     )
 
 
+def is_mutable_buffer_node(
+    node: torch.fx.Node, exported_program: ExportedProgram
+) -> bool:
+    if node.target not in exported_program.graph_signature.inputs_to_buffers:
+        return False
+    buf = exported_program.graph_signature.inputs_to_buffers[node.target]
+    return buf in exported_program.graph_signature.buffers_to_mutate.values()
+
+
 def is_symint_node(node: torch.fx.Node) -> bool:
     """
     Returns true if the given node produces a SymInt value
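A minimal usage sketch (not part of the patch) for the new is_mutable_buffer_node helper. The Counter module, the plain torch.export.export call, and the mutable_buffer_inputs name are assumptions for illustration only; the sketch lists the placeholder nodes of an exported program whose targets map to buffers that the graph mutates, which is the set of outputs that process_output_node() now skips.

# Hypothetical example; names below are illustrative, not part of the change.
import torch
from torch.export import export

from executorch.backends.vulkan.utils import is_mutable_buffer_node


class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The in-place update makes `count` a mutable buffer in the export.
        self.count.add_(1)
        return x + self.count


ep = export(Counter(), (torch.randn(1),))

# Placeholder nodes whose targets map to buffers that the graph mutates;
# these correspond to the outputs the serializer no longer records.
mutable_buffer_inputs = [
    node
    for node in ep.graph.nodes
    if node.op == "placeholder" and is_mutable_buffer_node(node, ep)
]
print([node.name for node in mutable_buffer_inputs])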