From caa3c70a0e1c3baad2aafddc2567e72f0294c587 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 26 May 2025 21:41:00 -0700 Subject: [PATCH] [ET-VK] Creating specialized version of conv2d pw shader for X and Y stride = 1 and padding = 0. This diff creates a specialized version of the conv2d pw shader for X and Y stride equals 1 and padding equals 0. * It adds a new file `conv2d_pw_s1p0.glsl`, which implements the conv2d pw shader for X and Y stride equals 1 and padding equals 0. * It adds a new file `conv2d_pw_s1p0.yaml`, which defines the parameters and shader variants for the specialized conv2d pw shader. * The file `Convolution.cpp` is modified to add a new parameter `stride_1_padding_0` to the `conv2d` function, which enables the use of the specialized shader. Differential Revision: [D75423931](https://our.internmc.facebook.com/intern/diff/D75423931/) [ghstack-poisoned] --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 163 ++++++++++++++++++ .../graph/ops/glsl/conv2d_pw_s1p0.yaml | 21 +++ .../runtime/graph/ops/impl/Convolution.cpp | 12 +- 3 files changed, 193 insertions(+), 3 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl new file mode 100644 index 00000000000..36c7a61eb3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -0,0 +1,163 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +#define TILE_SIZE_X ${TILE_SIZE_X} +#define TILE_SIZE_Y ${TILE_SIZE_Y} + +#define op(X, A, B) ${OPERATOR} + +#include "indexing_utils.h" + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")} +${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_limits; + ivec2 stride; + ivec2 padding; + int in_group_size; + int dummy_padding; + float out_min; + float out_max; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#extension GL_EXT_control_flow_attributes : require + +/* + * Computes a 2D pointwise convolution of an NxN output tile. Calculating an + * output tile for pointwise convolution is more efficient because the kernel + * size is only 1x1, making it easier to re-use loaded texels from t_kernel. + */ +void main() { + const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y}; + + const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]); + const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)}; + + // If the top left position is out of bounds, then this invocation will have + // no work to do. + if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) { + return; + } + + // Output position for TILE_SIZE = 2 + // +--------+--------+ + // | pos[0] | pos[1] | + // +--------+--------+ + // | pos[2] | pos[3] | + // +--------+--------+ + int pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; + for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) { + for (int x = 0; x < TILE_SIZE_X; ++x) { + pos[i * 2] = out_pos[0] * TILE_SIZE_X + x; + pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y; + i++; + } + } + + // Final output array where each element is a tensor value. + // Tuple of consecutive 4 elements represents a single output texel. + float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; + + const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0); + + // Initialize the output array with the bias value + for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) { + sum[i] = bias.x; + sum[i + 1] = bias.y; + sum[i + 2] = bias.z; + sum[i + 3] = bias.w; + } + + int z4 = 0; + // Since the kernel is 1x1, we only have to loop over the depth dimension. + for (int z = 0; z < in_group_size; z += 4, ++z4) { + // During prepacking, the weight tensor has been permuted so that the + // channel (IC) dim is along the x-axis, and the batch (OC) dim is along + // the z-axis. + float kernel_values[4 * 4]; // 4 channels, 4 elements per channel + + // Load kernel values from texels to array + [[unroll]] for (int i = 0; i < 4; ++i) { + const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0); + kernel_values[i * 4 + 0] = k_tex.x; + kernel_values[i * 4 + 1] = k_tex.y; + kernel_values[i * 4 + 2] = k_tex.z; + kernel_values[i * 4 + 3] = k_tex.w; + } + + for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { + const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); + // Load the input texel into an array + float tex_values[4]; + tex_values[0] = in_tex.x; + tex_values[1] = in_tex.y; + tex_values[2] = in_tex.z; + tex_values[3] = in_tex.w; + + // For 2x2 tile size algorithm works as follows. + // To explain the calculations below, the contents of one in_tex and the + // group of 4 texels loaded from t_kernel are shown: + // + // in_tex t_kernel + // -x-> ---x---> + // +---+ +----+----+----+----+ + // ^ | w | ^ | D0 | D1 | D2 | D3 | + // | +---+ | +----+----+----+----+ + // | | z | | | C0 | C1 | C2 | C3 | + // z +---+ z +----+----+----+----+ + // | | y | | | B0 | B2 | B2 | B3 | + // | +---+ | +----+----+----+----+ + // | x | | A0 | A1 | A2 | A3 | + // +---+ +----+----+----+----+ + // + // In the t_kernel graphic, cells sharing the same letter are from + // the same batch/output channel index, and the number denotes a unique + // channel index. To calculate the output texel, the following + // calculation is performed: + // + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | + // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ + // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // + // which is what is expressed in the following calculations. This is done + // for each output position. + for (int j = 0; j < 4; ++j) { + sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; + sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; + sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; + sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; + } + } + } + + for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { + const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]); + if (all(lessThan(pos_l, out_limits.xyz))) { + imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max)); + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml new file mode 100644 index 00000000000..ebfee11c405 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_pw_s1p0: + parameter_names_with_default_values: + OPERATOR: X + NDIM: 3 + DTYPE: float + TILE_SIZE_X: 1 + TILE_SIZE_Y: 4 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv2d_pw_s1p0 + - NAME: conv2d_pw_s1p0_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ba1f50a23c1..fbe4a61befc 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -127,7 +127,8 @@ vkapi::ShaderInfo get_conv2d_shader( const Conv2dMethod method, const ValueRef weight, const bool clamp_out = false, - const bool stride_equals_dilation = false) { + const bool stride_equals_dilation = false, + const bool stride_1_padding_0 = false) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); switch (method) { @@ -150,7 +151,7 @@ vkapi::ShaderInfo get_conv2d_shader( if (prepack_weights) { kernel_name = "conv2d"; } else { - kernel_name = "conv2d_pw"; + kernel_name = stride_1_padding_0 ? "conv2d_pw_s1p0" : "conv2d_pw"; } break; case Conv2dMethod::SlidingWindow: @@ -382,6 +383,10 @@ void add_conv2d_node( (kernel_params.stride[0] == kernel_params.dilation[0] && kernel_params.stride[1] == kernel_params.dilation[1]); + const bool stride_1_padding_0 = + (kernel_params.stride[0] == 1 && kernel_params.stride[1] == 1 && + kernel_params.padding[0] == 0 && kernel_params.padding[1] == 0); + OutputParams out_params = {out_min_val, out_max_val}; check_conv2d_params(kernel_params, transposed_val); @@ -393,7 +398,8 @@ void add_conv2d_node( method, weight_data, clamp_out, - stride_equals_dilation); + stride_equals_dilation, + stride_1_padding_0); utils::uvec3 wg_size = create_conv2d_global_wg_size( graph, method, out, weight_data, stride_equals_dilation);