Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,13 +582,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
}
}

// propagate_resize() will re-encode the command buffer so that push
// constants are updated and DynamicDispatchNode can update the compute
// shader, global workgroup size, and local workgroup size to perform the
// model inference.
if (should_propagate_resize ||
(compute_graph->graphconfig().expect_dynamic_shapes &&
compute_graph->execute_count() == 0u)) {
if (should_propagate_resize) {
compute_graph->propagate_resize();
}

Expand Down
68 changes: 63 additions & 5 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,44 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
return utils::kTexture3D;
}

bool ComputeGraph::was_value_updated(const ValueRef value_ref) const {
// Check if this ValueRef itself was updated
if (updated_values_.find(value_ref) != updated_values_.end()) {
return true;
}

// If this is a ValueList, check each ValueRef in the list
if (val_is_value_list(value_ref)) {
const auto& value_list = values_.at(value_ref).toConstValueList();
for (const auto& nested_value_ref : value_list) {
if (was_value_updated(nested_value_ref)) {
return true;
}
}
}

return false;
}

bool ComputeGraph::was_value_ref_updated(const ValueRef value_ref) const {
// Check if this ValueRef itself was updated
if (updated_values_.find(value_ref) != updated_values_.end()) {
return true;
}

// If this is a ValueList, check each ValueRef in the list
if (val_is_value_list(value_ref)) {
const auto& value_list = values_.at(value_ref).toConstValueList();
for (const auto& nested_value_ref : value_list) {
if (was_value_ref_updated(nested_value_ref)) {
return true;
}
}
}

return false;
}

utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
const std::vector<int64_t>& sizes) {
if (config_.enable_memory_layout_override) {
Expand Down Expand Up @@ -569,7 +607,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
get_symint(idx)->set(val);
int32_t cur_val = read_symint(idx);
if (cur_val != val) {
get_symint(idx)->set(val);
// Track that this ValueRef was updated
updated_values_.insert(idx);
}
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
Expand Down Expand Up @@ -921,6 +964,12 @@ void ComputeGraph::execute() {
}

execute_count_++;

// Clear the set of updated values at the end of inference
updated_values_.clear();

// Reset the re-encoding flag at the end of inference
requires_reencode_ = false;
}

void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
Expand All @@ -938,21 +987,30 @@ void ComputeGraph::resize_input(
const int64_t idx,
const std::vector<int64_t>& new_sizes) {
IOValueRef io_val = inputs_.at(idx);
get_tensor(io_val.value)->virtual_resize(new_sizes);
virtual_resize(io_val.value, new_sizes);
updated_values_.insert(io_val.staging);
}

void ComputeGraph::virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
std::vector<int64_t> cur_sizes = sizes_of(idx);
if (cur_sizes != new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
// Track that this ValueRef was updated
updated_values_.insert(idx);
}
}

void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
// Only re-encode on resize if dynamic shapes are expected
if (config_.expect_dynamic_shapes) {
// A command buffer re-encode will be needed if:
// 1. Any push constant data (used for tensor metadata) was updated
// 2. Compute shader dispatch parameters (i.e. compute shader, global and
// local work group sizes) were updated
if (requires_reencode_) {
clear_deferred_cmds();
}
}
Expand Down
39 changes: 34 additions & 5 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ class ComputeGraph final {
// List of command buffers deferred for submission
std::vector<vkapi::CommandBuffer> deferred_cmd_list_;

// Set to track which ValueRefs were updated during inference
std::unordered_set<ValueRef> updated_values_;

// Flag to indicate if re-encoding is required
bool requires_reencode_ = false;

protected:
size_t values_in_use_ = 0;
size_t execute_count_ = 0;
Expand Down Expand Up @@ -419,31 +425,41 @@ class ComputeGraph final {
}

inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorDimOrder);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorStrides);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorLogicalLimits);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
pc_data.set_value(idx);
return pc_data;
}

//
Expand Down Expand Up @@ -940,6 +956,19 @@ class ComputeGraph final {

void propagate_resize();

// Check if a specific ValueRef (or ValueList) was updated, with recursive
// handling
bool was_value_updated(const ValueRef value_ref) const;

// Check if a specific ValueRef (or ValueList) was updated, with recursive
// handling
bool was_value_ref_updated(const ValueRef value_ref) const;

// Set the flag to indicate that re-encoding is required
inline void set_requires_reencode() {
requires_reencode_ = true;
}

//
// Miscellaneous Utilities
//
Expand Down
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/containers/PushConstantData.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

namespace vkcompute {

class ComputeGraph;
Expand All @@ -33,6 +35,9 @@ class PushConstantDataInfo {
};

Payload payload_;
// The value in a compute graph that this push constant data is associated
// with, if any.
ValueRef value_ = kDummyValueRef;

public:
explicit PushConstantDataInfo(
Expand Down Expand Up @@ -60,6 +65,18 @@ class PushConstantDataInfo {
void* dst,
const uint32_t dst_offset,
const uint32_t max_dst_size) const;

inline bool is_tensor_metadata() const {
return tensorUniformData != nullptr;
}

inline void set_value(ValueRef value) {
value_ = value;
}

inline ValueRef value() const {
return value_;
}
};

} // namespace vkcompute
38 changes: 38 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,42 @@ void DispatchNode::write_push_constant_data() {
}
}

bool DispatchNode::trigger_resize(ComputeGraph* graph) {
bool any_value_updated = was_any_value_updated(graph);
if (resize_fn_ != nullptr && any_value_updated) {
resize_fn_(graph, args_, resize_args_);

// If this shader uses push constants, and the tensor metadata associated
// with the push constants has changed, then the command buffer needs to be
// re-encoded since push constants cannot be updated.
for (const auto& push_constant : push_constants_) {
if (push_constant.is_tensor_metadata() &&
graph->was_value_ref_updated(push_constant.value())) {
graph->set_requires_reencode();
}
}
}
return any_value_updated;
}

bool DispatchNode::was_any_value_updated(ComputeGraph* graph) const {
// Check all ValueRefs in ArgGroups
for (const auto& arg_group : args_) {
for (const auto& value_ref : arg_group.refs) {
if (graph->was_value_ref_updated(value_ref)) {
return true;
}
}
}

// Check all ValueRefs in resize_args
for (const auto& value_ref : resize_args_) {
if (graph->was_value_ref_updated(value_ref)) {
return true;
}
}

return false;
}

} // namespace vkcompute
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class DispatchNode : public ExecuteNode {

void encode(ComputeGraph* graph) override;

bool trigger_resize(ComputeGraph* graph) override;

private:
// Helper function to check if any ValueRef was updated
bool was_any_value_updated(ComputeGraph* graph) const;

protected:
vkapi::ShaderInfo shader_;
utils::uvec3 global_workgroup_size_;
Expand Down
Loading
Loading