@@ -20,6 +20,8 @@

#define BATCH_SIZE_Y ${BATCH_SIZE_Y}

+#define LOCAL_WG_SIZE 64

#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"
@@ -30,14 +32,28 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
-${layout_declare_ubo(4, "ivec3", "out_limits")}
-${layout_declare_ubo(5, "ivec4", "in_sizes")}
-${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
-${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
-${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

+layout(push_constant) uniform restrict Block {
+  ivec4 out_limits;
+  ivec4 in_sizes;
+  ivec2 kernel_size;
+  ivec2 stride;
+  ivec2 padding;
+  ivec2 dilation;
+  ivec2 overlay_region;
+  int in_group_size;
+  int dummy_padding;
+  float out_min;
+  float out_max;
+};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+// For performance improvement, reduce register usage by caching positions in shared memory.
+// Offset index by 1 every 16 points to avoid bank access conflict.
+#define offset_pos_index(index) (index + ((index) >> 4))
+shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)];

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
@@ -63,6 +79,8 @@ void main() {
return;
}

+  pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;
@@ -109,18 +127,19 @@
for (int j = 0; j < TILE_SIZE; j++, kx++) {
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
for (int s = 0; s < BATCH_SIZE_X; s++) {
-        sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
+        sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
}
}
}
}

+  const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
for (int y = 0; y < BATCH_SIZE_Y; y++) {
for (int x = 0; x < BATCH_SIZE_X; x++) {
-      if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
+      if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits.xyz))) {
continue;
}
-      imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
+      imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
}
}
}
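
A note on the shared-memory cache introduced in this file: offset_pos_index skips one slot after every 16 entries, so entries from consecutive invocations in the 64-wide workgroup stagger across shared-memory banks rather than colliding. The mapping is easy to sanity-check on the host; the following standalone C++ sketch (illustrative only, not part of this PR) mirrors the macro:

    #include <cstdio>

    // Mirrors the shader macro: one extra slot is skipped for every
    // 16 entries, so indices 16..31 shift by 1, 32..47 by 2, and so on.
    constexpr int offset_pos_index(int index) {
      return index + (index >> 4);
    }

    int main() {
      // With LOCAL_WG_SIZE = 64, the shader sizes pos_shared as
      // offset_pos_index(64) = 68 slots rather than 64.
      const int samples[] = {0, 15, 16, 31, 32, 63};
      for (int i : samples) {
        std::printf("%2d -> %2d\n", i, offset_pos_index(i));
      }
      return 0;
    }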
@@ -24,11 +24,20 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
-${layout_declare_ubo(4, "ivec3", "out_limits")}
-${layout_declare_ubo(5, "ivec4", "in_sizes")}
-${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
-${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
-${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

+layout(push_constant) uniform restrict Block {
+  ivec4 out_limits;
+  ivec4 in_sizes;
+  ivec2 kernel_size;
+  ivec2 stride;
+  ivec2 padding;
+  ivec2 dilation;
+  ivec2 overlay_region;
+  int in_group_size;
+  int dummy_padding;
+  float out_min;
+  float out_max;
+};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
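
Both shaders swap five UBO declarations for a single push-constant block with an identical member list, and the host must match its byte layout exactly. The dummy_padding member exists because the host writes extra_params (overlay_region plus in_group_size, 12 bytes of data) padded out to a full 16-byte ivec4, as seen in the Convolution.cpp hunk below; without the filler int, out_min would land at offset 76 instead of 80. A host-side mirror of the layout can be checked at compile time, as in this hypothetical sketch (plain structs stand in for the runtime's vector types; not the actual ExecuTorch code):

    #include <cstddef>

    struct ivec2 { int x, y; };
    struct ivec4 { int x, y, z, w; };

    // Hypothetical mirror of the shader's push-constant block. With this
    // member order every field is naturally aligned, so the plain C++
    // offsets coincide with the std430-style offsets the shader uses.
    struct Conv2dPushConstants {
      ivec4 out_limits;      // offset  0
      ivec4 in_sizes;        // offset 16
      ivec2 kernel_size;     // offset 32
      ivec2 stride;          // offset 40
      ivec2 padding;         // offset 48
      ivec2 dilation;        // offset 56
      ivec2 overlay_region;  // offset 64
      int in_group_size;     // offset 72
      int dummy_padding;     // offset 76, completes the 16-byte slot
      float out_min;         // offset 80
      float out_max;         // offset 84
    };

    // Six ranges pushed by the host: 16+16+16+16+16+8 = 88 bytes total.
    static_assert(sizeof(Conv2dPushConstants) == 88, "unexpected size");
    static_assert(offsetof(Conv2dPushConstants, out_min) == 80,
                  "dummy_padding is required to keep out_min at offset 80");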

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp (38 additions, 48 deletions)
@@ -407,7 +407,9 @@ void add_conv2d_node(
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
}

-  if (method == Conv2dMethod::Pointwise) {
+  vkapi::ParamsBindList param_buffers;
+  std::vector<PushConstantDataInfo> push_constants;
+  if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
const utils::ivec4 kernel_param_size_stride = {
kernel_params.kernel_size[0],
kernel_params.kernel_size[1],
@@ -420,55 +422,43 @@
kernel_params.dilation[0],
kernel_params.dilation[1]};

-    graph.execute_nodes().emplace_back(new DispatchNode(
-        graph,
-        shader,
-        wg_size,
-        graph.create_local_wg_size(wg_size),
-        // Inputs and Outputs
-        {{out, vkapi::MemoryAccessType::WRITE},
-         {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
-        // Shader params buffers
-        {},
-        // Specialization Constants
-        {},
-        // Resizing Logic
-        resize_conv2d_node,
-        {weight_data, stride, padding, dilation, transposed, output_padding},
-        {
-            graph.logical_limits_pc_of(out),
-            graph.sizes_pc_of(in),
-            PushConstantDataInfo(
-                &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
-            PushConstantDataInfo(
-                &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
-            PushConstantDataInfo(
-                &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
-            PushConstantDataInfo(&out_params, sizeof(out_params)),
-        }));
+    push_constants = {
+        graph.logical_limits_pc_of(out),
+        graph.sizes_pc_of(in),
+        PushConstantDataInfo(
+            &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
+        PushConstantDataInfo(
+            &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
+        PushConstantDataInfo(
+            &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
+        PushConstantDataInfo(&out_params, sizeof(out_params)),
+    };
} else {
-    graph.execute_nodes().emplace_back(new DispatchNode(
-        graph,
-        shader,
-        wg_size,
-        graph.create_local_wg_size(wg_size),
-        // Inputs and Outputs
-        {{out, vkapi::MemoryAccessType::WRITE},
-         {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
-        // Shader params buffers
-        {
-            t_out->logical_limits_ubo(),
-            t_in->sizes_ubo(),
-            graph.create_params_buffer(kernel_params),
-            graph.create_params_buffer(extra_params),
-            graph.create_params_buffer(out_params),
-        },
-        // Specialization Constants
-        {},
-        // Resizing Logic
-        resize_conv2d_node,
-        {weight_data, stride, padding, dilation, transposed, output_padding}));
+    param_buffers = {
+        t_out->logical_limits_ubo(),
+        t_in->sizes_ubo(),
+        graph.create_params_buffer(kernel_params),
+        graph.create_params_buffer(extra_params),
+        graph.create_params_buffer(out_params),
+    };
}

+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      shader,
+      wg_size,
+      graph.create_local_wg_size(wg_size),
+      // Inputs and Outputs
+      {{out, vkapi::MemoryAccessType::WRITE},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+      // Shader params buffers
+      param_buffers,
+      // Specialization Constants
+      {},
+      // Resizing Logic
+      resize_conv2d_node,
+      {weight_data, stride, padding, dilation, transposed, output_padding},
+      push_constants));
}

void add_conv1d_node(
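
The restructuring above deduplicates the two DispatchNode constructions: each branch now only fills in param_buffers (UBO-backed, for the remaining methods) or push_constants (for pointwise and depthwise), and a single emplace_back at the end performs the dispatch. The three-argument PushConstantDataInfo call shows the key mechanism: a payload can reserve more bytes than it fills, which is how the 12-byte extra_params occupies a 16-byte ivec4 slot. A simplified sketch of such an entry type (hypothetical, not the actual ExecuTorch class):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Holds a copy of the payload plus the (possibly larger) number of
    // bytes it should occupy in the push-constant range.
    class PushConstantEntry {
     public:
      PushConstantEntry(const void* data, std::size_t size,
                        std::size_t padded_size = 0)
          : size_(size), padded_size_(padded_size ? padded_size : size) {
        std::memcpy(bytes_, data, size_);
      }

      // Copies the payload into dst at offset, zero-fills the padding,
      // and returns the offset where the next entry should start.
      std::size_t write(std::uint8_t* dst, std::size_t offset) const {
        std::memcpy(dst + offset, bytes_, size_);
        std::memset(dst + offset + size_, 0, padded_size_ - size_);
        return offset + padded_size_;
      }

     private:
      std::size_t size_;
      std::size_t padded_size_;
      std::uint8_t bytes_[64];  // ample for the small structs pushed here
    };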
backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp (7 additions, 2 deletions)
@@ -73,13 +73,18 @@ void add_q_8w_linear_node(
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
+  // Create temporary tensors to store the width packed versions of mat1 and out
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  TmpTensor out_tmp(
+      &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    mat1_W_packed = mat1_tmp;
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+    out_W_packed = out_tmp;
}
ValueRef q_mat2 = prepack_standard(
graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
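
Switching from graph.add_tensor_like to TmpTensor moves the temporaries into scoped handles that are created ahead of the width-packing branch and assigned only when needed. The general pattern, assuming TmpTensor-style handles return their allocation to a pool on destruction so back-to-back temporaries can recycle memory, looks roughly like this simplified, hypothetical sketch:

    #include <vector>

    using ValueRef = int;  // stand-in for the graph's value handle

    // Trivial pool: hands out recycled refs before minting new ones.
    class TmpPool {
     public:
      ValueRef acquire() {
        if (!free_.empty()) {
          ValueRef r = free_.back();
          free_.pop_back();
          return r;
        }
        return next_++;
      }
      void release(ValueRef r) { free_.push_back(r); }

     private:
      std::vector<ValueRef> free_;
      ValueRef next_ = 0;
    };

    // RAII handle in the spirit of TmpTensor: acquires on construction,
    // releases on destruction, and converts to ValueRef where needed
    // (mirroring mat1_W_packed = mat1_tmp above).
    class ScopedTmp {
     public:
      explicit ScopedTmp(TmpPool& pool) : pool_(pool), ref_(pool.acquire()) {}
      ~ScopedTmp() { pool_.release(ref_); }
      ScopedTmp(const ScopedTmp&) = delete;
      ScopedTmp& operator=(const ScopedTmp&) = delete;
      operator ValueRef() const { return ref_; }

     private:
      TmpPool& pool_;
      ValueRef ref_;
    };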