Skip to content
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
encode_execute();
}

} // namespace vkcompute
26 changes: 14 additions & 12 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset += push_constant.write(
push_constants_data.data(),
push_constants_offset,
kMaxPushConstantSize);
}
write_push_constant_data();

context->report_shader_dispatch_start(
shader_.kernel_name,
Expand All @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
node_id_);

vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
Expand All @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
pipeline_barrier,
shader_,
global_workgroup_size_,
push_constants_data.data(),
push_constants_offset);
push_constants_data_.data(),
push_constants_offset_);

context->report_shader_dispatch_end();
}

void DispatchNode::write_push_constant_data() {
push_constants_offset_ = 0;
for (const auto& push_constant : push_constants_) {
push_constants_offset_ += push_constant.write(
push_constants_data_.data(),
push_constants_offset_,
kMaxPushConstantSize);
}
}

} // namespace vkcompute
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;

// For push constants
std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
uint32_t push_constants_offset_ = 0;

void write_push_constant_data();

public:
operator bool() const {
return shader_;
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ExecuteNode {
(void)graph;
}

inline void trigger_resize(ComputeGraph* graph) {
virtual inline void trigger_resize(ComputeGraph* graph) {
if (resize_fn_ != nullptr) {
resize_fn_(graph, args_, resize_args_);
}
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1660,9 +1660,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
for (auto& new_sizes : new_sizes_list) {
graph.get_tensor(a.value)->virtual_resize(new_sizes);
graph.get_tensor(b.value)->virtual_resize(new_sizes);
graph.get_tensor(c)->virtual_resize(new_sizes);
graph.get_tensor(d.value)->virtual_resize(new_sizes);
graph.get_tensor(e)->virtual_resize(new_sizes);
graph.propagate_resize();

float val_a = new_sizes[1] + 4.0f;
float val_b = new_sizes[2] + 1.5f;
Expand Down
Loading