Skip to content
Merged
5 changes: 4 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
* output at a single output location.
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec3 pos = ivec3(
gl_GlobalInvocationID.x % out_limits.x,
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));

if (any(greaterThanEqual(pos, out_limits))) {
return;
Expand Down
25 changes: 15 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,40 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const u16vec3 pos = u16vec3(
gl_GlobalInvocationID.x % out_limits.x,
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));

if (any(greaterThanEqual(pos, out_limits))) {
return;
}

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;
const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding);

// Compute the start and end of the input indices to load. Padding is assumed
// to be constant 0 padding, so any reads from the padding region is skipped.
const ivec2 start = ipos;
const ivec2 end = ipos + overlay_region.xy;
const u16vec2 start = ipos;
const u16vec2 end = ipos + u16vec2(overlay_region.xy);

VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
uint16_t kx = uint16_t(0);
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
kx++;
}
}
Expand Down
27 changes: 19 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,42 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];

/*
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
* output tile for pointwise convolution is more efficient because the kernel
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);
const uvec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;

const u16vec3 gpos = u16vec3(
gl_GlobalInvocationID.x % out_limits_scaled.x,
(gl_GlobalInvocationID.x / out_limits_scaled.x) % out_limits_scaled.y,
gl_GlobalInvocationID.x / (out_limits_scaled.x * out_limits_scaled.y));

// Output position for TILE_SIZE = 2
// +--------+--------+
// | pos[0] | pos[1] |
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
u16vec3 pos[TILE_SIZE * TILE_SIZE];
u16vec2 pos[TILE_SIZE * TILE_SIZE];
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
for (int x = 0; x < TILE_SIZE; ++x) {
pos[i] = u16vec3(
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
pos[i] = u16vec2(
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
i++;
}
}

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (any(greaterThanEqual(pos[0], out_limits))) {
if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) {
return;
}

Expand All @@ -68,7 +78,7 @@ void main() {
// the top-left element is in a region added by padding.
u16vec2 ipos[TILE_SIZE * TILE_SIZE];
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
}

vec4 sum[TILE_SIZE * TILE_SIZE];
Expand Down Expand Up @@ -133,8 +143,9 @@ void main() {
}

for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
if (all(lessThan(pos[i], out_limits))) {
imageStore(t_out, pos[i], op(sum[i], out_min, out_max));
const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
}
}
}
10 changes: 8 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,17 @@ void add_conv2d_node(
weight_data,
clamp_out);

utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);

if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
}

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
shader,
create_conv2d_global_wg_size(graph, method, out),
graph.create_local_wg_size(out),
wg_size,
graph.create_local_wg_size(wg_size),
// Inputs and Outputs
{{out, vkapi::MemoryAccessType::WRITE},
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
Expand Down
Loading