1010
1111#include < executorch/backends/vulkan/runtime/graph/Logging.h>
1212
13+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1314#include < executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
1415
1516#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
@@ -21,8 +22,8 @@ namespace vkcompute {
2122void resize_clone_node (
2223 ComputeGraph* graph,
2324 const std::vector<ArgGroup>& args,
24- const std::vector<ValueRef>& extra_args ) {
25- (void )extra_args ;
25+ const std::vector<ValueRef>& resize_args ) {
26+ (void )resize_args ;
2627 vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
2728 vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
2829 // TODO: support for when dimensionality doesn't match, i.e. clone is used to
@@ -41,11 +42,11 @@ void add_clone_node(
4142 std::string kernel_name = " clone" ;
4243 add_dtype_suffix (kernel_name, *t_out);
4344
44- graph.execute_nodes ().emplace_back (new DispatchNode (
45+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
4546 graph,
4647 VK_KERNEL_FROM_STR (kernel_name),
47- graph. create_global_wg_size (out) ,
48- graph. create_local_wg_size (out) ,
48+ default_pick_global_wg_size ,
49+ default_pick_local_wg_size ,
4950 // Inputs and Outputs
5051 {{out, vkapi::kWrite }, {in, vkapi::kRead }},
5152 // Parameter Buffers
@@ -60,6 +61,17 @@ void add_clone_node(
6061 resize_clone_node));
6162}
6263
64+ utils::uvec3 clone_image_to_buffer_global_wg_size (
65+ ComputeGraph* graph,
66+ const vkapi::ShaderInfo& shader,
67+ const std::vector<ArgGroup>& args,
68+ const std::vector<ValueRef>& resize_args) {
69+ (void )shader;
70+ (void )resize_args;
71+ const ValueRef image = args.at (1 ).refs .at (0 );
72+ return graph->create_global_wg_size (image);
73+ }
74+
6375void add_image_to_buffer_node (
6476 ComputeGraph& graph,
6577 const ValueRef image,
@@ -68,12 +80,11 @@ void add_image_to_buffer_node(
6880 add_dtype_suffix (kernel_name, graph.dtype_of (image));
6981 vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
7082
71- utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
72- graph.execute_nodes ().emplace_back (new DispatchNode (
83+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
7384 graph,
7485 shader,
75- global_wg_size ,
76- graph. create_local_wg_size (global_wg_size) ,
86+ clone_image_to_buffer_global_wg_size ,
87+ default_pick_local_wg_size ,
7788 // Input and Outputs
7889 {{buffer, vkapi::kWrite }, {image, vkapi::kRead }},
7990 // Parameter Buffers
@@ -96,12 +107,11 @@ void add_buffer_to_image_node(
96107 add_dtype_suffix (kernel_name, graph.dtype_of (image));
97108 vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
98109
99- utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
100- graph.execute_nodes ().emplace_back (new DispatchNode (
110+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
101111 graph,
102112 shader,
103- global_wg_size ,
104- graph. create_local_wg_size (global_wg_size) ,
113+ default_pick_global_wg_size ,
114+ default_pick_local_wg_size ,
105115 // Input and Outputs
106116 {{image, vkapi::kWrite }, {buffer, vkapi::kRead }},
107117 // Parameter Buffers
0 commit comments