Skip to content

Commit a016a3d

Browse files
authored
Merge branch 'main' into export-D75826474
2 parents e21336b + d7699d6 commit a016a3d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1172
-1016
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
499499
compute_graph->encode_prepack();
500500
compute_graph->prepack();
501501

502+
// TODO(ssjia): remove this once we can batch compile compute pipelines
503+
// during prepare().
502504
compute_graph->encode_execute();
503505

504506
return Error::Ok;
@@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
567569
}
568570
}
569571

572+
// propagate_resize() will re-encode the command buffer so that push
573+
// constants are updated and DynamicDispatchNode can update the compute
574+
// shader, global workgroup size, and local workgroup size to perform the
575+
// model inference.
570576
if (should_propagate_resize) {
571577
compute_graph->propagate_resize();
572578
}
579+
573580
compute_graph->execute();
574581

575582
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
492492
const ValueRef idx) {
493493
if (values_.at(idx).isInt()) {
494494
const int32_t val = extract_scalar<int32_t>(idx);
495-
create_params_buffer(val);
495+
return create_params_buffer(val);
496496
} else if (values_.at(idx).isSymInt()) {
497497
SymIntPtr symint = get_symint(idx);
498498
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
499499
}
500500
VK_THROW("Cannot create a int param buffer for the given value");
501501
}
502502

503+
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
504+
const ValueRef idx,
505+
const int32_t default_val) {
506+
if (values_.at(idx).isNone()) {
507+
return create_params_buffer(default_val);
508+
} else {
509+
return get_or_create_int_param_buffer(idx);
510+
}
511+
}
512+
503513
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
504514
get_symint(idx)->set(val);
505515
}
@@ -678,11 +688,12 @@ void ComputeGraph::encode_execute() {
678688
}
679689
}
680690

681-
void ComputeGraph::execute() const {
691+
void ComputeGraph::execute() {
682692
vkapi::VulkanFence fence = context_->fences().get_fence();
683693
context_->submit_cmd_to_gpu(fence.get_submit_handle());
684694
fence.wait();
685695
context_->fences().return_fence(fence);
696+
execute_count_++;
686697
}
687698

688699
void ComputeGraph::resize_input(
@@ -692,10 +703,17 @@ void ComputeGraph::resize_input(
692703
get_tensor(io_val.value)->virtual_resize(new_sizes);
693704
}
694705

706+
void ComputeGraph::virtual_resize(
707+
const ValueRef idx,
708+
const std::vector<int64_t>& new_sizes) {
709+
get_tensor(idx)->virtual_resize(new_sizes);
710+
}
711+
695712
void ComputeGraph::propagate_resize() {
696713
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
697714
node->trigger_resize(this);
698715
}
716+
encode_execute();
699717
}
700718

701719
} // namespace vkcompute

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class ComputeGraph final {
187187

188188
protected:
189189
size_t values_in_use_ = 0;
190+
size_t execute_count_ = 0;
190191

191192
public:
192193
//
@@ -397,6 +398,19 @@ class ComputeGraph final {
397398
std::optional<T> extract_optional_scalar(const ValueRef idx) {
398399
if (val_is_none(idx)) {
399400
return ::std::nullopt;
401+
} else if (val_is_symint(idx)) {
402+
return utils::safe_downcast<T>(read_symint(idx));
403+
} else {
404+
return extract_scalar<T>(idx);
405+
}
406+
}
407+
408+
template <typename T>
409+
T extract_optional_scalar(const ValueRef idx, const T default_val) {
410+
if (val_is_none(idx)) {
411+
return default_val;
412+
} else if (val_is_symint(idx)) {
413+
return utils::safe_downcast<T>(read_symint(idx));
400414
} else {
401415
return extract_scalar<T>(idx);
402416
}
@@ -608,6 +622,10 @@ class ComputeGraph final {
608622
*/
609623
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
610624

625+
vkapi::BufferBindInfo get_or_create_int_param_buffer(
626+
const ValueRef idx,
627+
const int32_t default_value);
628+
611629
void set_symint(const ValueRef idx, const int32_t val);
612630

613631
int32_t read_symint(const ValueRef idx);
@@ -745,13 +763,16 @@ class ComputeGraph final {
745763
//
746764

747765
void encode_execute();
748-
void execute() const;
766+
void execute();
749767

750768
//
751769
// Dynamic Shape support
752770
//
753771

754772
void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
773+
void virtual_resize(
774+
const ValueRef idx,
775+
const std::vector<int64_t>& new_sizes);
755776
void propagate_resize();
756777

757778
//
@@ -762,6 +783,10 @@ class ComputeGraph final {
762783
return context_->adapter_ptr()->supports_int16_shader_types();
763784
}
764785

786+
inline size_t execute_count() const {
787+
return execute_count_;
788+
}
789+
765790
/*
766791
* Check whether the GPU supports 8 bit buffers.
767792
*/

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
4646

4747
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4848

49-
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
50-
uint32_t push_constants_offset = 0;
51-
52-
for (const auto& push_constant : push_constants_) {
53-
push_constants_offset += push_constant.write(
54-
push_constants_data.data(),
55-
push_constants_offset,
56-
kMaxPushConstantSize);
57-
}
49+
write_push_constant_data();
5850

5951
context->report_shader_dispatch_start(
6052
shader_.kernel_name,
@@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
6355
node_id_);
6456

6557
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
66-
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
58+
shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);
6759

6860
uint32_t idx = 0;
6961
idx = bind_values_to_descriptor_set(
@@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
7668
pipeline_barrier,
7769
shader_,
7870
global_workgroup_size_,
79-
push_constants_data.data(),
80-
push_constants_offset);
71+
push_constants_data_.data(),
72+
push_constants_offset_);
8173

8274
context->report_shader_dispatch_end();
8375
}
8476

77+
void DispatchNode::write_push_constant_data() {
78+
push_constants_offset_ = 0;
79+
for (const auto& push_constant : push_constants_) {
80+
push_constants_offset_ += push_constant.write(
81+
push_constants_data_.data(),
82+
push_constants_offset_,
83+
kMaxPushConstantSize);
84+
}
85+
}
86+
8587
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
5050
const vkapi::SpecVarList spec_vars_;
5151
const std::vector<PushConstantDataInfo> push_constants_;
5252

53+
// For push constants
54+
std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
55+
uint32_t push_constants_offset_ = 0;
56+
57+
void write_push_constant_data();
58+
5359
public:
5460
operator bool() const {
5561
return shader_;

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
2525
const ResizeFunction& resize_fn)
2626
: DispatchNode(
2727
graph,
28-
pick_shader_fn(&graph, args, resize_args),
29-
pick_global_wg_fn(&graph, args, resize_args),
30-
pick_local_wg_fn(&graph, args, resize_args),
28+
vkapi::ShaderInfo(),
29+
{1u, 1u, 1u},
30+
{1u, 1u, 1u},
3131
args,
3232
params,
3333
push_constants,
@@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode(
3636
resize_fn),
3737
pick_shader_fn_(pick_shader_fn),
3838
pick_global_wg_fn_(pick_global_wg_fn),
39+
pick_local_wg_fn_(pick_local_wg_fn) {
40+
shader_ = pick_shader_fn(&graph, args, resize_args);
41+
global_workgroup_size_ =
42+
pick_global_wg_fn(&graph, shader_, args, resize_args);
43+
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
44+
&graph, shader_, global_workgroup_size_, args, resize_args));
45+
}
46+
47+
DynamicDispatchNode::DynamicDispatchNode(
48+
ComputeGraph& graph,
49+
const vkapi::ShaderInfo& shader,
50+
const PickGlobalFn& pick_global_wg_fn,
51+
const PickLocalFn& pick_local_wg_fn,
52+
const std::vector<ArgGroup>& args,
53+
const vkapi::ParamsBindList& params,
54+
const std::vector<PushConstantDataInfo>& push_constants,
55+
const vkapi::SpecVarList& spec_vars,
56+
const std::vector<ValueRef>& resize_args,
57+
const ResizeFunction& resize_fn)
58+
: DispatchNode(
59+
graph,
60+
shader,
61+
pick_global_wg_fn(&graph, shader, args, resize_args),
62+
pick_local_wg_fn(
63+
&graph,
64+
shader,
65+
pick_global_wg_fn(&graph, shader, args, resize_args),
66+
args,
67+
resize_args),
68+
args,
69+
params,
70+
push_constants,
71+
spec_vars,
72+
resize_args,
73+
resize_fn),
74+
pick_shader_fn_{nullptr},
75+
pick_global_wg_fn_(pick_global_wg_fn),
3976
pick_local_wg_fn_(pick_local_wg_fn) {}
4077

4178
void DynamicDispatchNode::encode(ComputeGraph* graph) {
42-
shader_ = pick_shader_fn_(graph, args_, resize_args_);
43-
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44-
local_workgroup_size_ =
45-
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
79+
if (pick_shader_fn_) {
80+
shader_ = pick_shader_fn_(graph, args_, resize_args_);
81+
}
82+
if (pick_global_wg_fn_) {
83+
global_workgroup_size_ =
84+
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
85+
}
86+
if (pick_local_wg_fn_) {
87+
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
88+
graph, shader_, global_workgroup_size_, args_, resize_args_));
89+
}
4690
DispatchNode::encode(graph);
4791
}
4892

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode {
3232
const std::vector<ValueRef>&)>;
3333
using PickGlobalFn = const std::function<utils::uvec3(
3434
ComputeGraph*,
35+
const vkapi::ShaderInfo& shader,
3536
const std::vector<ArgGroup>&,
3637
const std::vector<ValueRef>&)>;
3738
using PickLocalFn = const std::function<utils::uvec3(
3839
ComputeGraph*,
40+
const vkapi::ShaderInfo& shader,
41+
const utils::uvec3& global_workgroup_size,
3942
const std::vector<ArgGroup>&,
4043
const std::vector<ValueRef>&)>;
4144

@@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode {
5154
const std::vector<ValueRef>& resize_args,
5255
const ResizeFunction& resize_fn = nullptr);
5356

57+
explicit DynamicDispatchNode(
58+
ComputeGraph& graph,
59+
const vkapi::ShaderInfo& shader,
60+
const PickGlobalFn& pick_global_wg_fn,
61+
const PickLocalFn& pick_local_wg_fn,
62+
const std::vector<ArgGroup>& args,
63+
const vkapi::ParamsBindList& params,
64+
const std::vector<PushConstantDataInfo>& push_constants,
65+
const vkapi::SpecVarList& spec_vars,
66+
const std::vector<ValueRef>& resize_args,
67+
const ResizeFunction& resize_fn = nullptr);
68+
5469
~DynamicDispatchNode() override = default;
5570

5671
void encode(ComputeGraph* graph) override;

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class ExecuteNode {
6565
(void)graph;
6666
}
6767

68-
inline void trigger_resize(ComputeGraph* graph) {
68+
virtual inline void trigger_resize(ComputeGraph* graph) {
6969
if (resize_fn_ != nullptr) {
7070
resize_fn_(graph, args_, resize_args_);
7171
}

0 commit comments

Comments
 (0)