Skip to content

Commit 56e131b

Browse files
authored
Using push constants for binary scalar op parameter.
Differential Revision: D88097606 Pull Request resolved: pytorch#16114
1 parent 41292e5 commit 56e131b

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2929
${layout_declare_ubo(B, "BufferMetadata", "outp")}
3030
${layout_declare_ubo(B, "BufferMetadata", "inp")}
3131

32-
${layout_declare_ubo(B, "float", "scalar_value")}
32+
layout(push_constant) uniform restrict Block {
33+
float scalar_value;
34+
};
3335

3436
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3537

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
3030
${layout_declare_ubo(B, "TextureMetadata", "outp")}
3131
${layout_declare_ubo(B, "TextureMetadata", "inp")}
3232

33-
${layout_declare_ubo(B, "float", "scalar_value")}
33+
layout(push_constant) uniform restrict Block {
34+
float scalar_value;
35+
};
3436

3537
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3638

backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ void add_binary_scalar_op_node(
4848
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
4949
add_dtype_suffix(kernel_name, graph.dtype_of(in));
5050

51-
vkapi::ParamsBindList param_ubos = {
52-
graph.meta_ubo(out),
53-
graph.meta_ubo(in),
54-
graph.create_params_buffer(scalar_val)};
51+
vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out), graph.meta_ubo(in)};
5552

5653
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5754
graph,
@@ -63,7 +60,7 @@ void add_binary_scalar_op_node(
6360
// Shader params buffers
6461
param_ubos,
6562
// Push Constants
66-
{},
63+
{PushConstantDataInfo(&scalar_val, sizeof(scalar_val))},
6764
// Specialization Constants
6865
{},
6966
// Resize Args

0 commit comments

Comments
 (0)