|
11 | 11 | // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName |
12 | 12 |
|
13 | 13 | #include <executorch/backends/vulkan/runtime/utils/MacroUtils.h> |
| 14 | +#include <executorch/backends/vulkan/runtime/utils/VecUtils.h> |
14 | 15 |
|
15 | 16 | #include <executorch/backends/vulkan/runtime/vk_api/Adapter.h> |
16 | 17 | #include <executorch/backends/vulkan/runtime/vk_api/Command.h> |
@@ -150,7 +151,7 @@ class Context final { |
150 | 151 | void report_shader_dispatch_start( |
151 | 152 | const std::string& shader_name, |
152 | 153 | const utils::uvec3& global_wg_size, |
153 | | - const utils::uvec3& local_wg_size, |
| 154 | + const utils::WorkgroupSize& local_wg_size, |
154 | 155 | const uint32_t dispatch_id = UINT32_MAX); |
155 | 156 |
|
156 | 157 | /* |
@@ -189,13 +190,13 @@ class Context final { |
189 | 190 |
|
190 | 191 | vkapi::DescriptorSet get_descriptor_set( |
191 | 192 | const vkapi::ShaderInfo&, |
192 | | - const utils::uvec3&, |
| 193 | + const utils::WorkgroupSize&, |
193 | 194 | const vkapi::SpecVarList&, |
194 | 195 | const uint32_t push_constants_size); |
195 | 196 |
|
196 | 197 | inline vkapi::DescriptorSet get_descriptor_set( |
197 | 198 | const vkapi::ShaderInfo& shader_descriptor, |
198 | | - const utils::uvec3& local_work_group_size) { |
| 199 | + const utils::WorkgroupSize& local_work_group_size) { |
199 | 200 | return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u); |
200 | 201 | } |
201 | 202 |
|
@@ -362,14 +363,17 @@ inline bool Context::submit_compute_job( |
362 | 363 | report_shader_dispatch_start( |
363 | 364 | shader.kernel_name, |
364 | 365 | global_work_group, |
365 | | - local_work_group_size, |
| 366 | + utils::WorkgroupSize(local_work_group_size), |
366 | 367 | dispatch_id); |
367 | 368 |
|
368 | 369 | // Factor out template parameter independent code to minimize code bloat. |
369 | 370 | // Note that push constants are not exposed yet via this API, therefore the |
370 | 371 | // push constants size is assumed to be 0. |
371 | 372 | vkapi::DescriptorSet descriptor_set = get_descriptor_set( |
372 | | - shader, local_work_group_size, specialization_constants, 0u); |
| 373 | + shader, |
| 374 | + utils::WorkgroupSize(local_work_group_size), |
| 375 | + specialization_constants, |
| 376 | + 0u); |
373 | 377 |
|
374 | 378 | detail::bind( |
375 | 379 | descriptor_set, |
|
0 commit comments