diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index b645905939f..bb7ce482a7a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -25,12 +25,15 @@ layout(std430) buffer; ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} + +layout(push_constant) uniform restrict Block { $if STORAGE == "buffer": - ${layout_declare_ubo(2, "int", "numel")} + int numel; $else: - ${layout_declare_ubo(2, "ivec3", "out_limits")} -${layout_declare_ubo(3, "float", "minimum")} -${layout_declare_ubo(4, "float", "maximum")} + ivec4 out_limits; +float minimum; +float maximum; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -53,7 +56,7 @@ void main() { void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_limits))) { + if (any(greaterThanEqual(pos, out_limits.xyz))) { return; } diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 518148f12eb..ea8daf2ea64 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -43,15 +43,7 @@ void add_unary_op_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - vkapi::ParamsBindList ubos({}); - if (graph.is_buffer_storage(out)) { - ubos.append({graph.numel_ubo(out)}); - } else { - ubos.append({graph.logical_limits_ubo(out)}); - } - ubos.append( - {graph.create_params_buffer(min), graph.create_params_buffer(max)}); - + const utils::vec2 min_max = {min, max}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -60,9 +52,14 @@ void add_unary_op_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers - ubos, - // Push Constants {}, + // Push Constants + { + graph.is_buffer_storage(out) ? graph.numel_pc_of(out) + : graph.logical_limits_pc_of(out), + PushConstantDataInfo(&min_max, sizeof(min_max)), + }, + // pcs, // Specialization Constants {}, // Resize Args