Skip to content

Commit 1e1df9e

Browse files
committed
[ET-VK] Add mechanism to trigger command buffer re-encode only when necessary
## Context

Dynamic shape models currently require the command buffer to be re-encoded on every inference. However, this introduces significant overhead when running models that require dynamic shapes.

The reality is that a command buffer re-encode may not be needed every frame. A command buffer re-encode will only be needed when:

1. Shader dispatch parameters change; i.e. new tensor sizes require a completely different compute shader, require new local work group sizing, or require a new work group grid size (i.e. global work group size / local work group size)
2. Push constants containing tensor metadata need to be updated

This diff aims to reduce the overhead of triggering a tensor shape change by detecting when a command buffer re-encode is actually needed.

## Changes

`ComputeGraph`:

* Introduce a `requires_reencode` flag to `ComputeGraph` to indicate when a command buffer re-encode is needed.
* Introduce a `std::unordered_set<ValueRef>` tracking which values were updated when propagating tensor sizes
  * "update" can be one of two things: 1) tensor sizes changed 2) symint value changed

`DispatchNode`:

* When propagating new tensor sizes, only execute the resize function if any of the values participating in the computation have been updated
* Mark `requires_reencode` if any push constants associated with tensor metadata need to be updated

`DynamicDispatchNode`:

* Only recompute compute shader dispatch params if any of the values participating in the computation have been updated
* Mark `requires_reencode` if 1) a new compute shader is required, 2) local work group size changed, 3) work group grid size changed

Differential Revision: [D79813237](https://our.internmc.facebook.com/intern/diff/D79813237/) [ghstack-poisoned]
1 parent 016eece commit 1e1df9e

File tree

9 files changed

+245
-31
lines changed

9 files changed

+245
-31
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -582,13 +582,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
582582
}
583583
}
584584

585-
// propagate_resize() will re-encode the command buffer so that push
586-
// constants are updated and DynamicDispatchNode can update the compute
587-
// shader, global workgroup size, and local workgroup size to perform the
588-
// model inference.
589-
if (should_propagate_resize ||
590-
(compute_graph->graphconfig().expect_dynamic_shapes &&
591-
compute_graph->execute_count() == 0u)) {
585+
if (should_propagate_resize) {
592586
compute_graph->propagate_resize();
593587
}
594588

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,44 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
206206
return utils::kTexture3D;
207207
}
208208

209+
bool ComputeGraph::was_value_updated(const ValueRef value_ref) const {
210+
// Check if this ValueRef itself was updated
211+
if (updated_values_.find(value_ref) != updated_values_.end()) {
212+
return true;
213+
}
214+
215+
// If this is a ValueList, check each ValueRef in the list
216+
if (val_is_value_list(value_ref)) {
217+
const auto& value_list = values_.at(value_ref).toConstValueList();
218+
for (const auto& nested_value_ref : value_list) {
219+
if (was_value_updated(nested_value_ref)) {
220+
return true;
221+
}
222+
}
223+
}
224+
225+
return false;
226+
}
227+
228+
bool ComputeGraph::was_value_ref_updated(const ValueRef value_ref) const {
229+
// Check if this ValueRef itself was updated
230+
if (updated_values_.find(value_ref) != updated_values_.end()) {
231+
return true;
232+
}
233+
234+
// If this is a ValueList, check each ValueRef in the list
235+
if (val_is_value_list(value_ref)) {
236+
const auto& value_list = values_.at(value_ref).toConstValueList();
237+
for (const auto& nested_value_ref : value_list) {
238+
if (was_value_ref_updated(nested_value_ref)) {
239+
return true;
240+
}
241+
}
242+
}
243+
244+
return false;
245+
}
246+
209247
utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
210248
const std::vector<int64_t>& sizes) {
211249
if (config_.enable_memory_layout_override) {
@@ -569,7 +607,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
569607
}
570608

571609
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
572-
get_symint(idx)->set(val);
610+
int32_t cur_val = read_symint(idx);
611+
if (cur_val != val) {
612+
get_symint(idx)->set(val);
613+
// Track that this ValueRef was updated
614+
updated_values_.insert(idx);
615+
}
573616
}
574617

575618
int32_t ComputeGraph::read_symint(const ValueRef idx) {
@@ -921,6 +964,12 @@ void ComputeGraph::execute() {
921964
}
922965

923966
execute_count_++;
967+
968+
// Clear the set of updated values at the end of inference
969+
updated_values_.clear();
970+
971+
// Reset the re-encoding flag at the end of inference
972+
requires_reencode_ = false;
924973
}
925974

926975
void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
@@ -938,21 +987,30 @@ void ComputeGraph::resize_input(
938987
const int64_t idx,
939988
const std::vector<int64_t>& new_sizes) {
940989
IOValueRef io_val = inputs_.at(idx);
941-
get_tensor(io_val.value)->virtual_resize(new_sizes);
990+
virtual_resize(io_val.value, new_sizes);
991+
updated_values_.insert(io_val.staging);
942992
}
943993

944994
void ComputeGraph::virtual_resize(
945995
const ValueRef idx,
946996
const std::vector<int64_t>& new_sizes) {
947-
get_tensor(idx)->virtual_resize(new_sizes);
997+
std::vector<int64_t> cur_sizes = sizes_of(idx);
998+
if (cur_sizes != new_sizes) {
999+
get_tensor(idx)->virtual_resize(new_sizes);
1000+
// Track that this ValueRef was updated
1001+
updated_values_.insert(idx);
1002+
}
9481003
}
9491004

9501005
void ComputeGraph::propagate_resize() {
9511006
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
9521007
node->trigger_resize(this);
9531008
}
954-
// Only re-encode on resize if dynamic shapes are expected
955-
if (config_.expect_dynamic_shapes) {
1009+
// A command buffer re-encode will be needed if:
1010+
// 1. Any push constant data (used for tensor metadata) was updated
1011+
// 2. Compute shader dispatch parameters (i.e. compute shader, global and
1012+
// local work group sizes) were updated
1013+
if (requires_reencode_) {
9561014
clear_deferred_cmds();
9571015
}
9581016
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ class ComputeGraph final {
196196
// List of command buffers deferred for submission
197197
std::vector<vkapi::CommandBuffer> deferred_cmd_list_;
198198

199+
// Set to track which ValueRefs were updated during inference
200+
std::unordered_set<ValueRef> updated_values_;
201+
202+
// Flag to indicate if re-encoding is required
203+
bool requires_reencode_ = false;
204+
199205
protected:
200206
size_t values_in_use_ = 0;
201207
size_t execute_count_ = 0;
@@ -419,31 +425,41 @@ class ComputeGraph final {
419425
}
420426

421427
inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
422-
return PushConstantDataInfo(
428+
PushConstantDataInfo pc_data = PushConstantDataInfo(
423429
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
430+
pc_data.set_value(idx);
431+
return pc_data;
424432
}
425433

426434
inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
427-
return PushConstantDataInfo(
435+
PushConstantDataInfo pc_data = PushConstantDataInfo(
428436
values_.at(idx).toConstTensor().get_uniform_data(),
429437
api::kTensorDimOrder);
438+
pc_data.set_value(idx);
439+
return pc_data;
430440
}
431441

432442
inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
433-
return PushConstantDataInfo(
443+
PushConstantDataInfo pc_data = PushConstantDataInfo(
434444
values_.at(idx).toConstTensor().get_uniform_data(),
435445
api::kTensorStrides);
446+
pc_data.set_value(idx);
447+
return pc_data;
436448
}
437449

438450
inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
439-
return PushConstantDataInfo(
451+
PushConstantDataInfo pc_data = PushConstantDataInfo(
440452
values_.at(idx).toConstTensor().get_uniform_data(),
441453
api::kTensorLogicalLimits);
454+
pc_data.set_value(idx);
455+
return pc_data;
442456
}
443457

444458
inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
445-
return PushConstantDataInfo(
459+
PushConstantDataInfo pc_data = PushConstantDataInfo(
446460
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
461+
pc_data.set_value(idx);
462+
return pc_data;
447463
}
448464

449465
//
@@ -940,6 +956,19 @@ class ComputeGraph final {
940956

941957
void propagate_resize();
942958

959+
// Check if a specific ValueRef (or ValueList) was updated, with recursive
960+
// handling
961+
bool was_value_updated(const ValueRef value_ref) const;
962+
963+
// Check if a specific ValueRef (or ValueList) was updated, with recursive
964+
// handling
965+
bool was_value_ref_updated(const ValueRef value_ref) const;
966+
967+
// Set the flag to indicate that re-encoding is required
968+
inline void set_requires_reencode() {
969+
requires_reencode_ = true;
970+
}
971+
943972
//
944973
// Miscellaneous Utilities
945974
//

backends/vulkan/runtime/graph/containers/PushConstantData.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/backends/vulkan/runtime/api/api.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14+
1315
namespace vkcompute {
1416

1517
class ComputeGraph;
@@ -33,6 +35,9 @@ class PushConstantDataInfo {
3335
};
3436

3537
Payload payload_;
38+
// The value in a compute graph that this push constant data is associated
39+
// with, if any.
40+
ValueRef value_ = kDummyValueRef;
3641

3742
public:
3843
explicit PushConstantDataInfo(
@@ -60,6 +65,18 @@ class PushConstantDataInfo {
6065
void* dst,
6166
const uint32_t dst_offset,
6267
const uint32_t max_dst_size) const;
68+
69+
inline bool is_tensor_metadata() const {
70+
return tensorUniformData != nullptr;
71+
}
72+
73+
inline void set_value(ValueRef value) {
74+
value_ = value;
75+
}
76+
77+
inline ValueRef value() const {
78+
return value_;
79+
}
6380
};
6481

6582
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,42 @@ void DispatchNode::write_push_constant_data() {
8989
}
9090
}
9191

92+
bool DispatchNode::trigger_resize(ComputeGraph* graph) {
93+
bool any_value_updated = was_any_value_updated(graph);
94+
if (resize_fn_ != nullptr && any_value_updated) {
95+
resize_fn_(graph, args_, resize_args_);
96+
97+
// If this shader uses push constants, and the tensor metadata associated
98+
// with the push constants has changed, then the command buffer needs to be
99+
// re-encoded since push constants cannot be updated.
100+
for (const auto& push_constant : push_constants_) {
101+
if (push_constant.is_tensor_metadata() &&
102+
graph->was_value_ref_updated(push_constant.value())) {
103+
graph->set_requires_reencode();
104+
}
105+
}
106+
}
107+
return any_value_updated;
108+
}
109+
110+
bool DispatchNode::was_any_value_updated(ComputeGraph* graph) const {
111+
// Check all ValueRefs in ArgGroups
112+
for (const auto& arg_group : args_) {
113+
for (const auto& value_ref : arg_group.refs) {
114+
if (graph->was_value_ref_updated(value_ref)) {
115+
return true;
116+
}
117+
}
118+
}
119+
120+
// Check all ValueRefs in resize_args
121+
for (const auto& value_ref : resize_args_) {
122+
if (graph->was_value_ref_updated(value_ref)) {
123+
return true;
124+
}
125+
}
126+
127+
return false;
128+
}
129+
92130
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class DispatchNode : public ExecuteNode {
4444

4545
void encode(ComputeGraph* graph) override;
4646

47+
bool trigger_resize(ComputeGraph* graph) override;
48+
49+
private:
50+
// Helper function to check if any ValueRef was updated
51+
bool was_any_value_updated(ComputeGraph* graph) const;
52+
4753
protected:
4854
vkapi::ShaderInfo shader_;
4955
utils::uvec3 global_workgroup_size_;

0 commit comments

Comments
 (0)