Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
83b59e6
[ET-VK] Replace Uniform buffers with push constants for binary op
trivedivivek Dec 6, 2024
c91bd80
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 6, 2024
ecee457
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 6, 2024
2c76113
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 9, 2024
e6ff460
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 9, 2024
6104ce0
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 10, 2024
c5e22de
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 11, 2024
2ce975a
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 11, 2024
7767dbc
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 12, 2024
dec9a06
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 12, 2024
f1d1c4b
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 12, 2024
8b4e434
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 12, 2024
422aa90
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 16, 2024
b479914
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 17, 2024
3efa1e6
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 17, 2024
1b026ff
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 17, 2024
70e2d1a
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 17, 2024
4479bc9
Update on "[ET-VK] Replace Uniform buffers with push constants for b…
trivedivivek Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ layout(std430) buffer;
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}
${layout_declare_ubo(B, "ivec4", "other_sizes")}
${layout_declare_ubo(B, "ivec2", "broadcast_params")}
${layout_declare_ubo(B, "float", "alpha")}

#include "broadcasting_utils.h"
#include "indexing_utils.h"
Expand All @@ -40,6 +35,14 @@ const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 in_sizes;
ivec4 other_sizes;
ivec2 broadcast_params;
float alpha;
};

void main() {
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
const ivec4 tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, packed_dim);
Expand Down
17 changes: 10 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ void add_binary_op_node(
alpha_val = graph.extract_scalar<float>(alpha);
}

const utils::ivec2 broadcast_params = create_broadcast_params(*t_in1, *t_in2);
const struct BinaryOpsParams {
const utils::ivec2 broadcast_params;
const float alpha_val;
} binary_ops_params{create_broadcast_params(*t_in1, *t_in2), alpha_val};

std::string kernel_name("binary_");
kernel_name.reserve(kShaderNameReserve);
Expand All @@ -83,16 +86,16 @@ void add_binary_op_node(
{{out, vkapi::MemoryAccessType::WRITE},
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
// Shader params buffers
{t_out->sizes_ubo(),
t_in1->sizes_ubo(),
t_in2->sizes_ubo(),
graph.create_params_buffer(broadcast_params),
graph.create_params_buffer(alpha_val)},
{},
// Specialization Constants
{t_out->hashed_layout(), t_in1->hashed_layout(), t_in2->hashed_layout()},
// Resizing Logic
resize_binary_op_node,
{}));
{},
{{graph.sizes_pc_of(out),
graph.sizes_pc_of(arg1),
graph.sizes_pc_of(arg2),
PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}}));
}

#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
Expand Down
Loading