Skip to content
13 changes: 8 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}

// Shader parameters declared as a push-constant uniform block. The member set
// is selected by the STORAGE template variable: `numel` for "buffer" storage,
// `out_limits` otherwise.
// NOTE(review): this span contains both `${layout_declare_ubo(...)}` lines and
// plain member declarations for the same names (numel, out_limits, minimum,
// maximum). That looks like the removed and added sides of a diff interleaved;
// confirm that only one form is present in the real shader.
layout(push_constant) uniform restrict Block {
$if STORAGE == "buffer":
${layout_declare_ubo(2, "int", "numel")}
int numel;
$else:
${layout_declare_ubo(2, "ivec3", "out_limits")}
${layout_declare_ubo(3, "float", "minimum")}
${layout_declare_ubo(4, "float", "maximum")}
// Declared as ivec4 here; main() compares against `out_limits.xyz` only.
ivec4 out_limits;
// NOTE(review): indentation is lost in this view, so it is unclear whether
// `minimum`/`maximum` belong to the `$else` branch or to both branches — confirm.
float minimum;
float maximum;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand All @@ -53,7 +56,7 @@ void main() {
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
if (any(greaterThanEqual(pos, out_limits.xyz))) {
return;
}

Expand Down
19 changes: 8 additions & 11 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,7 @@ void add_unary_op_node(
add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));

vkapi::ParamsBindList ubos({});
if (graph.is_buffer_storage(out)) {
ubos.append({graph.numel_ubo(out)});
} else {
ubos.append({graph.logical_limits_ubo(out)});
}
ubos.append(
{graph.create_params_buffer(min), graph.create_params_buffer(max)});

const utils::vec2 min_max = {min, max};
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
Expand All @@ -60,9 +52,14 @@ void add_unary_op_node(
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Shader params buffers
ubos,
// Push Constants
{},
// Push Constants
{
graph.is_buffer_storage(out) ? graph.numel_pc_of(out)
: graph.logical_limits_pc_of(out),
PushConstantDataInfo(&min_max, sizeof(min_max)),
},
// pcs,
// Specialization Constants
{},
// Resize Args
Expand Down
Loading