Skip to content

Commit 3254ddf

Browse files
authored
[ET-VK] Add mechanism to trigger command buffer re-encode only when necessary (#13379)
## Context

Dynamic shape models currently require the command buffer to be re-encoded on every inference. However, this introduces significant overhead when running models that require dynamic shapes.

In reality, a command buffer re-encode may not be needed every frame. A command buffer re-encode is only needed when:

1. Shader dispatch parameters change; i.e. new tensor sizes require a completely different compute shader, require new local work group sizing, or require a new work group grid size (i.e. global work group size / local work group size)
2. Push constants containing tensor metadata need to be updated

This diff aims to reduce the overhead of triggering a tensor shape change by detecting when a command buffer re-encode is actually needed.

## Changes

`ComputeGraph`:

* Introduce `requires_reencode` flag to `ComputeGraph` to indicate when a command buffer re-encode is needed.
* Introduce a `std::unordered_set<ValueRef>` tracking which values were updated when propagating tensor sizes
  * "update" can be one of two things: 1) tensor sizes changed 2) symint value changed

`DispatchNode`:

* When propagating new tensor sizes, only execute the resize function if any of the values participating in the computation have been updated
* Mark `requires_reencode` if any push constants associated with tensor metadata need to be updated

`DynamicDispatchNode`:

* Only recompute compute shader dispatch params if any of the values participating in the computation have been updated
* Mark `requires_reencode` if 1) a new compute shader is required, 2) local work group size changed, 3) work group grid size changed

Differential Revision: [D79813237](https://our.internmc.facebook.com/intern/diff/D79813237/)
1 parent f95a3f7 commit 3254ddf

File tree

11 files changed

+245
-28
lines changed

11 files changed

+245
-28
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -583,13 +583,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
583583
}
584584
}
585585

586-
// propagate_resize() will re-encode the command buffer so that push
587-
// constants are updated and DynamicDispatchNode can update the compute
588-
// shader, global workgroup size, and local workgroup size to perform the
589-
// model inference.
590-
if (should_propagate_resize ||
591-
(compute_graph->graphconfig().expect_dynamic_shapes &&
592-
compute_graph->execute_count() == 0u)) {
586+
if (should_propagate_resize) {
593587
compute_graph->propagate_resize();
594588
}
595589

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,29 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
206206
return utils::kTexture3D;
207207
}
208208

209+
bool ComputeGraph::was_value_updated(const ValueRef idx) const noexcept {
210+
if (!is_valid_value_idx(idx)) {
211+
return false;
212+
}
213+
214+
// Check if this ValueRef itself was updated
215+
if (updated_values_.find(idx) != updated_values_.end()) {
216+
return true;
217+
}
218+
219+
// If this is a ValueList, check each ValueRef in the list
220+
if (val_is_value_list(idx)) {
221+
const auto& value_list = values_.at(idx).toConstValueList();
222+
for (const auto& nested_idx : value_list) {
223+
if (was_value_updated(nested_idx)) {
224+
return true;
225+
}
226+
}
227+
}
228+
229+
return false;
230+
}
231+
209232
utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
210233
const std::vector<int64_t>& sizes) {
211234
if (config_.enable_memory_layout_override) {
@@ -236,6 +259,10 @@ void ComputeGraph::check_no_active_value_ptrs() {
236259
"invalidated.");
237260
}
238261

262+
bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
263+
return idx >= 0 && idx < static_cast<int>(values_.size());
264+
}
265+
239266
std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
240267
const Value& val = values_.at(idx);
241268
if (val.isTensor()) {
@@ -569,7 +596,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
569596
}
570597

571598
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
572-
get_symint(idx)->set(val);
599+
int32_t cur_val = read_symint(idx);
600+
if (cur_val != val) {
601+
get_symint(idx)->set(val);
602+
// Track that this ValueRef was updated
603+
updated_values_.insert(idx);
604+
}
573605
}
574606

575607
int32_t ComputeGraph::read_symint(const ValueRef idx) {
@@ -951,6 +983,12 @@ void ComputeGraph::execute() {
951983
}
952984

953985
execute_count_++;
986+
987+
// Clear the set of updated values at the end of inference
988+
updated_values_.clear();
989+
990+
// Reset the re-encoding flag at the end of inference
991+
requires_reencode_ = false;
954992
}
955993

956994
void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
@@ -968,21 +1006,30 @@ void ComputeGraph::resize_input(
9681006
const int64_t idx,
9691007
const std::vector<int64_t>& new_sizes) {
9701008
IOValueRef io_val = inputs_.at(idx);
971-
get_tensor(io_val.value)->virtual_resize(new_sizes);
1009+
virtual_resize(io_val.value, new_sizes);
1010+
updated_values_.insert(io_val.staging);
9721011
}
9731012

9741013
void ComputeGraph::virtual_resize(
9751014
const ValueRef idx,
9761015
const std::vector<int64_t>& new_sizes) {
977-
get_tensor(idx)->virtual_resize(new_sizes);
1016+
std::vector<int64_t> cur_sizes = sizes_of(idx);
1017+
if (cur_sizes != new_sizes) {
1018+
get_tensor(idx)->virtual_resize(new_sizes);
1019+
// Track that this ValueRef was updated
1020+
updated_values_.insert(idx);
1021+
}
9781022
}
9791023

9801024
void ComputeGraph::propagate_resize() {
9811025
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
9821026
node->trigger_resize(this);
9831027
}
984-
// Only re-encode on resize if dynamic shapes are expected
985-
if (config_.expect_dynamic_shapes) {
1028+
// A command buffer re-encode will be needed if:
1029+
// 1. Any push constant data (used for tensor metadata) was updated
1030+
// 2. Compute shader dispatch parameters (i.e. compute shader, global and
1031+
// local work group sizes) were updated
1032+
if (requires_reencode_) {
9861033
clear_deferred_cmds();
9871034
}
9881035
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ class ComputeGraph final {
196196
// List of command buffers deferred for submission
197197
std::vector<vkapi::CommandBuffer> deferred_cmd_list_;
198198

199+
// Set to track which ValueRefs were updated during inference
200+
std::unordered_set<ValueRef> updated_values_;
201+
202+
// Flag to indicate if re-encoding is required
203+
bool requires_reencode_ = false;
204+
199205
protected:
200206
size_t values_in_use_ = 0;
201207
size_t execute_count_ = 0;
@@ -244,6 +250,9 @@ class ComputeGraph final {
244250
return config_;
245251
}
246252

253+
// Check if the ComputeGraph has a value at the specified index
254+
bool is_valid_value_idx(const ValueRef idx) const noexcept;
255+
247256
//
248257
// Value Extraction
249258
//
@@ -427,31 +436,41 @@ class ComputeGraph final {
427436
}
428437

429438
inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
430-
return PushConstantDataInfo(
439+
PushConstantDataInfo pc_data = PushConstantDataInfo(
431440
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
441+
pc_data.set_value(idx);
442+
return pc_data;
432443
}
433444

434445
inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
435-
return PushConstantDataInfo(
446+
PushConstantDataInfo pc_data = PushConstantDataInfo(
436447
values_.at(idx).toConstTensor().get_uniform_data(),
437448
api::kTensorDimOrder);
449+
pc_data.set_value(idx);
450+
return pc_data;
438451
}
439452

440453
inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
441-
return PushConstantDataInfo(
454+
PushConstantDataInfo pc_data = PushConstantDataInfo(
442455
values_.at(idx).toConstTensor().get_uniform_data(),
443456
api::kTensorStrides);
457+
pc_data.set_value(idx);
458+
return pc_data;
444459
}
445460

446461
inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
447-
return PushConstantDataInfo(
462+
PushConstantDataInfo pc_data = PushConstantDataInfo(
448463
values_.at(idx).toConstTensor().get_uniform_data(),
449464
api::kTensorLogicalLimits);
465+
pc_data.set_value(idx);
466+
return pc_data;
450467
}
451468

452469
inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
453-
return PushConstantDataInfo(
470+
PushConstantDataInfo pc_data = PushConstantDataInfo(
454471
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
472+
pc_data.set_value(idx);
473+
return pc_data;
455474
}
456475

457476
//
@@ -948,6 +967,15 @@ class ComputeGraph final {
948967

949968
void propagate_resize();
950969

970+
// Check if a specific ValueRef (or ValueList) was updated, with recursive
971+
// handling
972+
bool was_value_updated(const ValueRef idx) const noexcept;
973+
974+
// Set the flag to indicate that re-encoding is required
975+
inline void set_requires_reencode() noexcept {
976+
requires_reencode_ = true;
977+
}
978+
951979
//
952980
// Miscellaneous Utilities
953981
//

backends/vulkan/runtime/graph/containers/PushConstantData.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/backends/vulkan/runtime/api/api.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14+
1315
namespace vkcompute {
1416

1517
class ComputeGraph;
@@ -33,6 +35,9 @@ class PushConstantDataInfo {
3335
};
3436

3537
Payload payload_;
38+
// The value in a compute graph that this push constant data is associated
39+
// with, if any.
40+
ValueRef value_ = kDummyValueRef;
3641

3742
public:
3843
explicit PushConstantDataInfo(
@@ -60,6 +65,18 @@ class PushConstantDataInfo {
6065
void* dst,
6166
const uint32_t dst_offset,
6267
const uint32_t max_dst_size) const;
68+
69+
inline bool is_tensor_metadata() const noexcept {
70+
return tensorUniformData != nullptr;
71+
}
72+
73+
inline void set_value(ValueRef value) noexcept {
74+
value_ = value;
75+
}
76+
77+
inline ValueRef value() const noexcept {
78+
return value_;
79+
}
6380
};
6481

6582
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,21 @@ void DispatchNode::write_push_constant_data() {
8989
}
9090
}
9191

92+
bool DispatchNode::trigger_resize(ComputeGraph* graph) {
93+
const bool any_arg_updated = ExecuteNode::trigger_resize(graph);
94+
95+
if (any_arg_updated) {
96+
// If this shader uses push constants, and the tensor metadata associated
97+
// with the push constants has changed, then the command buffer needs to be
98+
// re-encoded since push constants cannot be updated.
99+
for (const auto& push_constant : push_constants_) {
100+
if (push_constant.is_tensor_metadata() &&
101+
graph->was_value_updated(push_constant.value())) {
102+
graph->set_requires_reencode();
103+
}
104+
}
105+
}
106+
return any_arg_updated;
107+
}
108+
92109
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class DispatchNode : public ExecuteNode {
4444

4545
void encode(ComputeGraph* graph) override;
4646

47+
bool trigger_resize(ComputeGraph* graph) override;
48+
4749
protected:
4850
vkapi::ShaderInfo shader_;
4951
utils::uvec3 global_workgroup_size_;

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ DynamicDispatchNode::DynamicDispatchNode(
4141
pick_global_wg_fn(&graph, shader_, args, resize_args);
4242
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
4343
&graph, shader_, global_workgroup_size_, args, resize_args));
44+
45+
// Calculate dispatch grid similar to Context.cpp register_shader_dispatch
46+
wg_dispatch_grid_ = {
47+
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
48+
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
49+
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
4450
}
4551

4652
DynamicDispatchNode::DynamicDispatchNode(
@@ -72,21 +78,74 @@ DynamicDispatchNode::DynamicDispatchNode(
7278
pick_global_wg_fn(&graph, shader_, args, resize_args);
7379
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
7480
&graph, shader_, global_workgroup_size_, args, resize_args));
81+
// Calculate the work group grid that will be dispatched
82+
wg_dispatch_grid_ = {
83+
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
84+
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
85+
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
7586
}
7687

77-
void DynamicDispatchNode::encode(ComputeGraph* graph) {
88+
bool DynamicDispatchNode::trigger_resize(ComputeGraph* graph) {
89+
// DispatchNode::trigger_resize() will return true if any of the values
90+
// participating in this operation were updated.
91+
const bool any_arg_updated = DispatchNode::trigger_resize(graph);
92+
// Only re-compute the shader, global workgroup size, and local workgroup size
93+
// if any of the values participating in this operation were updated.
94+
// Otherwise, assume that these will not have changed.
95+
if (!any_arg_updated) {
96+
return false;
97+
}
98+
99+
// Indicates if the shader dispatch should be changed since the last time the
100+
// command buffer was encoded.
101+
bool dispatch_params_changed = false;
102+
78103
if (pick_shader_fn_) {
79-
shader_ = pick_shader_fn_(graph, args_, resize_args_);
104+
vkapi::ShaderInfo new_shader = pick_shader_fn_(graph, args_, resize_args_);
105+
// Compare shader kernel names as a proxy for shader equality
106+
if (shader_.kernel_name != new_shader.kernel_name) {
107+
shader_ = new_shader;
108+
dispatch_params_changed = true;
109+
}
80110
}
81111
if (pick_global_wg_fn_) {
112+
// Note that if global workgroup size changes, then the dispatch params
113+
// may not actually be different. The actual value to check is the
114+
// work group grid size that will be dispatched, which is calculated
115+
// below.
82116
global_workgroup_size_ =
83117
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
84118
}
85119
if (pick_local_wg_fn_) {
86-
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
87-
graph, shader_, global_workgroup_size_, args_, resize_args_));
120+
utils::uvec3 new_local_wg_uvec3 = pick_local_wg_fn_(
121+
graph, shader_, global_workgroup_size_, args_, resize_args_);
122+
utils::WorkgroupSize new_local_wg =
123+
utils::WorkgroupSize(new_local_wg_uvec3);
124+
if (local_workgroup_size_ != new_local_wg) {
125+
local_workgroup_size_ = new_local_wg;
126+
dispatch_params_changed = true;
127+
}
128+
}
129+
130+
// Always recompute the new dispatch grid and check if it's different
131+
utils::uvec3 new_wg_dispatch_grid = {
132+
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
133+
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
134+
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
135+
136+
// Check if the new dispatch grid is different from the old one
137+
if (wg_dispatch_grid_ != new_wg_dispatch_grid) {
138+
dispatch_params_changed = true;
88139
}
89-
DispatchNode::encode(graph);
140+
wg_dispatch_grid_ = new_wg_dispatch_grid;
141+
142+
// If any of the dispatch params have changed, then the command buffer must
143+
// be re-encoded.
144+
if (dispatch_params_changed) {
145+
graph->set_requires_reencode();
146+
}
147+
148+
return true;
90149
}
91150

92151
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,15 @@ class DynamicDispatchNode final : public DispatchNode {
6868

6969
~DynamicDispatchNode() override = default;
7070

71-
void encode(ComputeGraph* graph) override;
71+
bool trigger_resize(ComputeGraph* graph) override;
7272

7373
protected:
7474
const PickShaderFn pick_shader_fn_;
7575
const PickGlobalFn pick_global_wg_fn_;
7676
const PickLocalFn pick_local_wg_fn_;
7777

78+
utils::uvec3 wg_dispatch_grid_{1u, 1u, 1u};
79+
7880
public:
7981
operator bool() const {
8082
return shader_;

0 commit comments

Comments
 (0)