diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
index 1597b05e8d6..a58a4d3a457 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
@@ -56,13 +56,9 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
-// calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
-//
-// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
-// shader invocations, where each invocation computes 1 result. But that
-// performs worse.
+// This implementation performs out_L x out_C x ceil(N / 4) shader invocations,
+// where each invocation computes a single output texel, i.e. the result for
+// one output length position and one output channel across up to 4 batches.
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
@@ -70,61 +66,53 @@ void main() {
     return;
   }
 
-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = lpos.y;
-  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const int out_c = lpos.y;
 
   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
+
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" indexes the batch texel; each texel packs up to 4 batches.
+  const int N = lpos.z;
 
   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. It may not be trival choice for
-        // performance reason since need to have more data fetch. The reason is
-        // for some sequence model, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has a large
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use channel-packing
-        // for the weight tensor, yielding a 75% reduction in weight-tensor
-        // memory.
-
-        // It is possible to further reduce the memory footprint by swapping the
-        // dimensions, using x extent for out_channel, and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. This may not be the obvious choice
+    // for performance, since it requires more data fetches. However, for some
+    // sequence models we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has
+    // out_channel >> kernel, leading to non-optimal memory use as the
+    // weight tensor gets very deep. As a mitigation, we use channel-packing
+    // for the weight tensor, yielding a 75% reduction in weight-tensor
+    // memory.
+
+    // It is possible to further reduce the memory footprint by swapping the
+    // dimensions, using the x extent for out_channel and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 060f5028c02..6097a432148 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -505,17 +505,24 @@ void add_conv1d_node(
 
   check_conv_args(*t_in, *t_out);
 
-  int32_t in_channels = in_sizes.at(1);
-  int32_t out_channels = weight_sizes.at(0);
-  int32_t kernel_size = weight_sizes.at(2);
-  int32_t stride_size = graph.get_int_list(stride)->at(0);
-  int32_t padding_size = graph.get_int_list(padding)->at(0);
-  int32_t dilation_size = graph.get_int_list(dilation)->at(0);
-  int32_t in_group_size = static_cast<int32_t>(in_channels / groups_val);
-  int32_t out_group_size = static_cast<int32_t>(out_channels / groups_val);
-
-  utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
-  utils::uvec3 local_size = {1, 64, 1};
+  const int32_t in_channels = in_sizes.at(1);
+  const int32_t out_channels = weight_sizes.at(0);
+  const int32_t kernel_size = weight_sizes.at(2);
+  const int32_t stride_size = graph.get_int_list(stride)->at(0);
+  const int32_t padding_size = graph.get_int_list(padding)->at(0);
+  const int32_t dilation_size = graph.get_int_list(dilation)->at(0);
+  const int32_t in_group_size = static_cast<int32_t>(in_channels / groups_val);
+  const int32_t out_group_size =
+      static_cast<int32_t>(out_channels / groups_val);
+
+  const utils::uvec3 global_size = {
+      // out length
+      graph.size_at<uint32_t>(-1, out),
+      // out channels
+      static_cast<uint32_t>(out_channels),
+      // out batches
+      utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
 
   Kernel1dParams kernel_params = {
       kernel_size,
@@ -525,7 +532,7 @@
       in_group_size,
       out_group_size};
 
-  OutputParams out_params = {out_min_val, out_max_val};
+  const OutputParams out_params = {out_min_val, out_max_val};
 
   std::string kernel_name("conv1d");
   if (clamp_out) {
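
For reviewers, below is a minimal CPU reference sketch of the computation the reworked shader performs per output element. It is not part of the diff: the function name, the flattened row-major layouts, and the scalar indexing are illustrative assumptions, and it omits the texel packing (4 batches per texel), the out-of-bound zero reads, and the optional clamp epilogue of the real shader. Each iteration of the outer three loops corresponds to one shader invocation in the new out_L x out_C x ceil(N / 4) dispatch, modulo the batch packing.

```cpp
#include <vector>

// Hypothetical reference, not ExecuTorch API. Layouts (row-major, flattened):
//   in:     [N, in_C, in_L]
//   weight: [out_C, in_C / groups, K]
//   bias:   [out_C]
//   out:    [N, out_C, out_L]
void conv1d_reference(
    const std::vector<float>& in,
    const std::vector<float>& weight,
    const std::vector<float>& bias,
    std::vector<float>& out,
    int N, int in_C, int in_L,
    int out_C, int out_L,
    int K, int stride, int padding, int dilation, int groups) {
  const int in_group_size = in_C / groups;
  const int out_group_size = out_C / groups;
  for (int n = 0; n < N; ++n) {
    for (int out_c = 0; out_c < out_C; ++out_c) {
      // Input channel range of the group this output channel belongs to.
      const int c_start = (out_c / out_group_size) * in_group_size;
      for (int out_l = 0; out_l < out_L; ++out_l) {
        // Leftmost input position overlapped by the kernel for this output.
        const int in_l = out_l * stride - padding;
        float sum = 0.0f;
        for (int ic = 0; ic < in_group_size; ++ic) {
          const int in_c = c_start + ic;
          for (int k = 0; k < K; ++k) {
            const int l = in_l + k * dilation;
            if (l < 0 || l >= in_L) {
              continue;  // the shader reads out-of-bound zeros instead
            }
            sum += weight[(out_c * in_group_size + ic) * K + k] *
                in[(n * in_C + in_c) * in_L + l];
          }
        }
        out[(n * out_C + out_c) * out_L + out_l] = sum + bias[out_c];
      }
    }
  }
}
```

On the GPU, the `n` loop collapses to roughly ceil(N / 4) invocations because four batches share one texel along the packed dimension, which is why the z extent of `global_size` uses `utils::div_up_4`.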