Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,11 @@ class GraphBuilder {
vkFn(*compute_graph_, args);
}

// Parse the outputs, which will be mostly tensors. For some reason,
// mutable buffers are shown to be returned in the fx.Graph but do not get
// returned by the delegate; this may be an implementation detail of how the
// executorch emitter handles mutable buffers.
// Parse the outputs, which will be mostly tensors but may contain tensorref
// values as well if the source graph returns parameter nodes.
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
const ValueRef ref = get_fb_id_valueref(fb_id);
if (compute_graph_->val_is_tensor(ref)) {
compute_graph_->set_output_tensor(ref);
}
compute_graph_->set_output_value(ref);
}

if (compute_graph_->graphconfig().enable_querypool) {
Expand Down Expand Up @@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
compute_graph->outputs()[i].staging,
args[o]->toTensor().mutable_data_ptr(),
args[o]->toTensor().numel());
}
// TensorRef values represent constant tensors which will not have been
// modified by the graph execution. Therefore, if a constant tensor is
// returned as an output, no action is required.
else if (compute_graph->val_is_tref(oref)) {
continue;
} else {
VK_THROW(
"Could not handle output with type ",
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor(
return idx;
}

ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
if (values_.at(idx).isTensor()) {
return set_output_tensor(idx);
}
outputs_.push_back({idx, kDummyValueRef});
return idx;
}

vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
const ValueRef idx) {
if (values_.at(idx).isInt()) {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ class ComputeGraph final {
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);

ValueRef set_output_value(const ValueRef idx);

template <typename Block>
vkapi::BufferBindInfo create_params_buffer(const Block& data) {
param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from executorch.backends.vulkan.utils import (
is_constant,
is_get_attr_node,
is_mutable_buffer_node,
is_param_node,
is_symint_node,
)
Expand Down Expand Up @@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None:
"the output node is being serialized before its corresponding "
"internal node which is not allowed."
)
# Mutable buffers outputs are not included as an output to the
# delegate call. Skip marking them as an output.
if is_mutable_buffer_node(out_node, self.program):
continue

self.output_ids.append(self.node_to_value_ids[out_node])

def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
Expand Down
9 changes: 9 additions & 0 deletions backends/vulkan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
)


def is_mutable_buffer_node(
node: torch.fx.Node, exported_program: ExportedProgram
) -> bool:
if node.target not in exported_program.graph_signature.inputs_to_buffers:
return False
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
return buf in exported_program.graph_signature.buffers_to_mutate.values()


def is_symint_node(node: torch.fx.Node) -> bool:
"""
Returns true if the given node produces a SymInt value
Expand Down
Loading