Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/col2im.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#extension GL_EXT_debug_printf : enable
#define DEBUG_MODE

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)}
#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)}

$if OUTPUT_STORAGE == "buffer":
#define OUTPUT_BUFFER
$if INPUT_STORAGE == "buffer":
#define INPUT_BUFFER

#define TILE_M4 1
#define TILE_N4 1
#define TILE_K4 1

#define TILE_M 4
#define TILE_N 4
#define TILE_K 4

${define_required_extensions(DTYPE)}

#extension GL_EXT_debug_printf : enable

layout(std430) buffer;

#include "conv2d_common.glslh"

${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}

// Sizes of the convolution output image
${layout_declare_ubo(B, "ivec4", "output_sizes")}
// Sizes of the convolution input image
${layout_declare_ubo(B, "ivec4", "input_sizes")}
// Sizes of the im2col matrix of the convolution output
${layout_declare_ubo(B, "ivec4", "matrix_sizes")}

${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "conv2d_fp_im2col_block_store.glslh"

#ifdef INPUT_BUFFER

void load_matrix_tile(
out FPOutTile tile,
const int n4,
const int m_start,
const int N4) {
[[unroll]] for (int m = 0; m < TILE_M; m++) {
tile.data[m][0] = t_input[(m_start + m) * N4 + n4];
}
}

#else // INPUT_TEXTURE

void load_matrix_tile(
out FPOutTile tile,
const int n4,
const int m_start,
const int N4) {
[[unroll]] for (int m = 0; m < TILE_M; m++) {
tile.data[m][0] = texelFetch(
t_input, ivec3(n4, m_start + m, 0), 0);
}
}

#endif // INPUT_BUFFER

void main() {
// Each thread loads and writes a 4 wide x 4 high block of the matrix
const int n4 = int(gl_GlobalInvocationID.x);
const int m4 = int(gl_GlobalInvocationID.y);

const int n = mul_4(n4);
const int m = mul_4(m4);

if (n >= matrix_sizes.x || m >= matrix_sizes.y) {
return;
}

FPOutTile tile;

const int N4 = div_4(matrix_sizes.x);
load_matrix_tile(tile, n4, m, N4);
write_im2col_tile_as_image(tile, n4, m);
}
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/col2im.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

col2im:
parameter_names_with_default_values:
DTYPE: float
OUTPUT_STORAGE: texture3d
INPUT_STORAGE: buffer
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: col2im_texture3d_buffer
- NAME: col2im_texture3d_texture3d
INPUT_STORAGE: texture3d
36 changes: 36 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef COMMON_GLSLH
#define COMMON_GLSLH

#define div_up_4(x) (((x) + 3) >> 2)

#define mul_4(x) ((x) << 2)
#define div_4(x) ((x) >> 2)

#define mod_4(x) ((x) & 3)

struct TensorIndex4D {
ivec4 data;
};

#ifdef DEBUG_MODE

void printTensorIndex4D(const TensorIndex4D index) {
debugPrintfEXT(
"tensor_idx: %d, %d, %d, %d\\n",
index.data.x,
index.data.y,
index.data.z,
index.data.w);
}

#endif // DEBUG_MODE

#endif // COMMON_GLSLH
36 changes: 36 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef CONV2D_COMMON_GLSLH
#define CONV2D_COMMON_GLSLH

#include "common.glslh"

struct Conv2DParams {
ivec2 kernel_size;
ivec2 stride;
ivec2 padding;
ivec2 dilation;
int groups;
};

#ifdef DEBUG_MODE

void printConv2DParams(const Conv2DParams params) {
debugPrintfEXT("Conv2DParams: \\n");
debugPrintfEXT(
" kernel_size: %d, %d\\n", params.kernel_size.x, params.kernel_size.y);
debugPrintfEXT(" stride: %d, %d\\n", params.stride.x, params.stride.y);
debugPrintfEXT(" padding: %d, %d\\n", params.padding.x, params.padding.y);
debugPrintfEXT(" dilation: %d, %d\\n", params.dilation.x, params.dilation.y);
debugPrintfEXT(" groups: %d\\n", params.groups);
}

#endif // DEBUG_MODE

#endif // CONV2D_COMMON_GLSLH
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef CONV2D_FP_IM2COL_BLOCK
#define CONV2D_FP_IM2COL_BLOCK

#extension GL_EXT_control_flow_attributes : require

#include "common.glslh"
#include "conv2d_common.glslh"

struct Im2ColTensorIdx {
int row;
int col;
};

void unwrap_m(out TensorIndex4D out_tidx_base, const int m) {
out_tidx_base.data[3] = m / (output_sizes.y * output_sizes.x);
out_tidx_base.data[1] = (m / output_sizes.x) % output_sizes.y;
out_tidx_base.data[0] = m % output_sizes.x;

// Initialize channels to 0; assume it will be set later on
out_tidx_base.data[2] = 0;
}

void im2col_tidx_to_output_tidx(
out TensorIndex4D output_tidx,
const Im2ColTensorIdx im2col_tidx) {
unwrap_m(output_tidx, im2col_tidx.row);
// Set channels
output_tidx.data.z = im2col_tidx.col;
}

void im2col_tidx_to_input_tidx(
out TensorIndex4D input_tidx,
const Im2ColTensorIdx im2col_tidx) {
// Use utility function to unwrap m index
TensorIndex4D output_tidx;
unwrap_m(output_tidx, im2col_tidx.row);

// Extract kernel position and channel from k index
// k is structured as: kernel_y * (kernel_width * channels) + kernel_x *
// channels + channel
const int input_channels = input_sizes.z;
const int channel_idx = im2col_tidx.col % input_channels;
const int kernel_x =
(im2col_tidx.col / input_channels) % conv2d_params.kernel_size.x;
const int kernel_y =
im2col_tidx.col / (input_channels * conv2d_params.kernel_size.x);

// Calculate input coordinates
const int input_y = int(output_tidx.data.y * conv2d_params.stride.y) -
int(conv2d_params.padding.y) + int(kernel_y * conv2d_params.dilation.y);
const int input_x = int(output_tidx.data.x * conv2d_params.stride.x) -
int(conv2d_params.padding.x) + int(kernel_x * conv2d_params.dilation.x);

input_tidx.data = ivec4(input_x, input_y, channel_idx, output_tidx.data.w);
}

// 4x4 block of the im2col matrix
struct FPIm2ColBlock {
VEC4_T data[4];
};

#endif // CONV2D_FP_IM2COL_BLOCK
Loading
Loading