|
8 | 8 |
|
9 | 9 | #include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h> |
10 | 10 |
|
11 | | -#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> |
12 | | - |
13 | | -#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h> |
14 | | - |
15 | 11 | namespace vkcompute { |
16 | | - |
17 | 12 | ExecuteNode::ExecuteNode( |
18 | | - ComputeGraph& graph, |
19 | | - const vkapi::ShaderInfo& shader, |
20 | | - const utils::uvec3& global_workgroup_size, |
21 | | - const utils::uvec3& local_workgroup_size, |
22 | | - const std::vector<ArgGroup>& args, |
23 | | - const vkapi::ParamsBindList& params, |
24 | | - const vkapi::SpecVarList& spec_vars, |
25 | 13 | const ResizeFunction& resize_fn, |
26 | | - const std::vector<ValueRef>& resize_args) |
27 | | - : shader_(shader), |
28 | | - global_workgroup_size_(global_workgroup_size), |
29 | | - local_workgroup_size_(local_workgroup_size), |
| 14 | + const std::vector<ValueRef>& resize_args, |
| 15 | + const std::vector<ArgGroup>& args, |
| 16 | + const std::string& name) |
| 17 | + : resize_fn_(resize_fn), |
| 18 | + resize_args_(resize_args), |
30 | 19 | args_(args), |
31 | | - params_(params), |
32 | | - spec_vars_(spec_vars), |
33 | | - resize_fn_(resize_fn), |
34 | | - resize_args_(resize_args) { |
35 | | - graph.update_descriptor_counts(shader, /*execute = */ true); |
36 | | -} |
37 | | - |
38 | | -ExecuteNode::ExecuteNode( |
39 | | - const ResizeFunction& resize_fn, |
40 | | - const std::vector<ValueRef>& resize_args) |
41 | | - : shader_(), |
42 | | - global_workgroup_size_({0u, 0u, 0u}), |
43 | | - local_workgroup_size_({0u, 0u, 0u}), |
44 | | - args_(), |
45 | | - params_(), |
46 | | - spec_vars_(), |
47 | | - resize_fn_(resize_fn), |
48 | | - resize_args_(resize_args) {} |
49 | | - |
50 | | -void ExecuteNode::encode(ComputeGraph* graph) { |
51 | | - if (!shader_) { |
52 | | - return; |
53 | | - } |
54 | | - api::Context* const context = graph->context(); |
55 | | - vkapi::PipelineBarrier pipeline_barrier{}; |
56 | | - |
57 | | - std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock(); |
58 | | - |
59 | | - context->report_shader_dispatch_start( |
60 | | - shader_.kernel_name, |
61 | | - global_workgroup_size_, |
62 | | - local_workgroup_size_, |
63 | | - node_id_); |
64 | | - |
65 | | - vkapi::DescriptorSet descriptor_set = |
66 | | - context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_); |
67 | | - |
68 | | - uint32_t idx = 0; |
69 | | - idx = bind_values_to_descriptor_set( |
70 | | - graph, args_, pipeline_barrier, descriptor_set, idx); |
71 | | - |
72 | | - bind_params_to_descriptor_set(params_, descriptor_set, idx); |
73 | | - |
74 | | - context->register_shader_dispatch( |
75 | | - descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); |
76 | | - |
77 | | - context->report_shader_dispatch_end(); |
78 | | -} |
79 | | - |
| 20 | + name_(name) {} |
80 | 21 | } // namespace vkcompute |
0 commit comments