diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..bb39c204c70b9 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/conv.h" #include "core/providers/webgpu/nn/conv2d_mm.h" +#include "core/providers/webgpu/nn/im2col_matmul.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/tensor/transpose.h" @@ -99,10 +100,34 @@ Status Conv::ComputeInternal(ComputeContext& context modified_input_output_shapes.push_back(bias->Shape()); } modified_input_output_shapes.push_back(TensorShape(output_shape_vector)); + + const auto input_height = input_shape[is_channels_last ? 1 : 2]; + const auto input_width = input_shape[is_channels_last ? 2 : 3]; + const auto input_channels = input_shape[is_channels_last ? 3 : 1]; + const auto kernel_height = kernel_shape[2]; + const auto kernel_width = kernel_shape[3]; + const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; + const auto output_width = output_shape_vector[is_channels_last ? 2 : 3]; + uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0; auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? 
pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; std::vector updated_pads{pad0, pad1}; + + if (CanApplyIm2ColMatMulProgram(context, + is_channels_last, + activation_.activation_kind_, + kernel_shape, + conv_attrs_.auto_pad, + onnxruntime::narrow(conv_attrs_.group))) { + return ApplyIm2ColMatMulProgram(context, + is_channels_last, + dilations, + pads, + strides, + output); + } + if (conv_attrs_.group > 1) { Tensor transposed_kernel; if (is_channels_last) { @@ -128,13 +153,6 @@ Status Conv::ComputeInternal(ComputeContext& context } return context.RunProgram(program); } - const auto input_height = input_shape[is_channels_last ? 1 : 2]; - const auto input_width = input_shape[is_channels_last ? 2 : 3]; - const auto input_channels = input_shape[is_channels_last ? 3 : 1]; - const auto kernel_height = kernel_shape[2]; - const auto kernel_width = kernel_shape[3]; - const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; - const auto output_width = output_shape_vector[is_channels_last ? 2 : 3]; const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0; if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) { diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc new file mode 100644 index 0000000000000..5012194aaf843 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include +#include + +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/nn/im2col_matmul.h" +#include "core/providers/webgpu/nn/activation_util.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { + +// TODO: move to common header. 
// Integer division rounding towards positive infinity.
//
// Computed as quotient plus a remainder correction instead of
// `(numerator + denominator - 1) / denominator`, so the intermediate sum can
// never wrap around for large unsigned inputs.
// `denominator` must be non-zero.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return static_cast<T>(numerator / denominator + (numerator % denominator != 0 ? 1 : 0));
}

// Chooses the tile size (tile_m, tile_n) for the im2col-matmul shader.
// The candidates are performance-tuned and may vary depending on the target device.
//
// im2col_m / im2col_n are the M and N extents of the conceptual matmul.
// Returns the first candidate, in order of preference, whose tiling yields at
// least 128 workgroups (ceil(M / tile_m) * ceil(N / tile_n)); when no candidate
// reaches that, the last (smallest) candidate is returned.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
  // Preferred (tile_m, tile_n) pairs in descending order of preference.
  // A constant table: no per-call allocation.
  static constexpr std::pair<uint32_t, uint32_t> kTileSizes[] = {
      {32u, 64u},
      {16u, 64u},
  };
  constexpr uint64_t kMinDispatch = 128;

  const std::pair<uint32_t, uint32_t>* chosen = nullptr;
  for (const auto& tile : kTileSizes) {
    chosen = &tile;

    // The product is formed in 64 bits so that a very large dispatch cannot
    // wrap around and be mistaken for a small one.
    const uint64_t dispatch_m = ceil_div(im2col_m, tile.first);
    const uint64_t dispatch_n = ceil_div(im2col_n, tile.second);
    if (dispatch_m * dispatch_n >= kMinDispatch) {
      break;
    }
  }

  // If none of the tile sizes meets the dispatch >= 128 requirement, the loop
  // ran to completion and `chosen` points at the last entry of the table.
  return *chosen;
}

// TODO: Add support for more devices.
+bool IsDeviceSupported(ComputeContext& context) { + const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); + + if (adapter_info.vendor == std::string_view("intel")) { + if (adapter_info.architecture == std::string_view("xe-2lpg")) { + return true; + } + } + + return false; +} + +} // namespace + +Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template", + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(src, src)); +} + +Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32."); + ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64."); + + return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_), + WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_), + WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(src, src), + WGSL_TEMPLATE_VARIABLE(weight, weight)); +} + +Status ApplyIm2ColMatMulProgram(ComputeContext& context, + bool is_channels_last, + const std::vector& dilations, + const std::vector& pads, + const std::vector& strides, 
+ Tensor* output) { + const auto* input = context.Input(0); + const auto* weight = context.Input(1); + const bool has_bias = context.InputCount() > 2; + const auto* bias = has_bias ? context.Input(2) : nullptr; + + // Transpose OIHW Weight to OHWI + TensorShape weight_shape = weight->Shape(); + const uint32_t channel_output = onnxruntime::narrow(weight_shape[0]); + const uint32_t channel_input = onnxruntime::narrow(weight_shape[1]); + const uint32_t kernel_height = onnxruntime::narrow(weight_shape[2]); + const uint32_t kernel_width = onnxruntime::narrow(weight_shape[3]); + + TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input}; + Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape); + OIHW2OHWIProgram transpose_program{}; + transpose_program.SetWorkgroupSize(64); + + const uint32_t Ci_tiles = ceil_div(channel_input, 64u); + transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles); + + transpose_program.AddInput({weight, + ProgramTensorMetadataDependency::TypeAndRank}); + transpose_program.AddOutput({&ohwi_weight, + ProgramTensorMetadataDependency::TypeAndRank}); + transpose_program.AddUniformVariables({{channel_output}, + {channel_input}, + {kernel_height}, + {kernel_width}, + {Ci_tiles}, + {ceil_div(kernel_height * kernel_height, 4u)}}); + ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program)); + + // im2col-matmul + const TensorShape input_shape = input->Shape(); + const TensorShape output_shape = output->Shape(); + + const uint32_t batch = onnxruntime::narrow(input_shape[0]); + const uint32_t input_height = onnxruntime::narrow(input_shape[is_channels_last ? 1 : 2]); + const uint32_t input_width = onnxruntime::narrow(input_shape[is_channels_last ? 2 : 3]); + const uint32_t output_height = onnxruntime::narrow(output_shape[is_channels_last ? 1 : 2]); + const uint32_t output_width = onnxruntime::narrow(output_shape[is_channels_last ? 
2 : 3]); + + const uint32_t im2col_m = output_height * output_width; + const uint32_t im2col_k = kernel_height * kernel_width * channel_input; + const uint32_t im2col_n = channel_output; + + const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n); + const uint32_t workgroup_size = tile_n; + + // Check the device's subgroup size before shader compilation to avoid potential performance penalties + // associated with conditional checks in the shader runtime. + // Ensure the subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`. + // If this condition is not met, the feature must be disabled. + const bool use_subgroup = true; + Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup}; + im2col_mm_program.SetWorkgroupSize(workgroup_size); + + const uint32_t M_tiles = ceil_div(im2col_m, tile_m); + const uint32_t N_tiles = ceil_div(im2col_n, tile_n); + im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch); + + im2col_mm_program.AddInput({input, + ProgramTensorMetadataDependency::TypeAndRank, + 4}); + im2col_mm_program.AddInput({&ohwi_weight, + ProgramTensorMetadataDependency::TypeAndRank, + 4}); + if (has_bias) { + im2col_mm_program.AddInput({bias, + ProgramTensorMetadataDependency::TypeAndRank}); + } + im2col_mm_program.AddOutput({output, + ProgramTensorMetadataDependency::TypeAndRank}); + im2col_mm_program.AddUniformVariables({{batch}, + {input_height}, + {input_width}, + {channel_input}, + {kernel_height}, + {kernel_width}, + {output_height}, + {output_width}, + {im2col_m}, + {im2col_k}, + {im2col_n}, + {M_tiles}, + {N_tiles}, + {ceil_div(ceil_div(im2col_k, 4u), 4u)}, + {dilations}, + {pads}, + {strides}}); + im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup); + + return context.RunProgram(im2col_mm_program); +} + +bool CanApplyIm2ColMatMulProgram(ComputeContext& context, + const bool is_channels_last, + const ActivationKind activation_kind, + const TensorShape weight_shape, + 
const AutoPadType auto_pad, + const uint32_t group) { + if (!IsDeviceSupported(context)) { + return false; + } + + // TODO: Support !is_channels_last + // TODO: Support fuse + // TODO: Support auto pad + // TODO: Support group conv + if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) { + return false; + } + + // TODO: Support conv1d + // TODO: Support conv2d_1x1 + const uint32_t kernel_height = onnxruntime::narrow(weight_shape[2]); + const uint32_t kernel_width = onnxruntime::narrow(weight_shape[3]); + if (kernel_height == 1 || kernel_width == 1) { + return false; + } + + // TODO: Support channel input vec1 + const uint32_t channel_input = onnxruntime::narrow(weight_shape[1]); + if (channel_input % 4 != 0) { + return false; + } + + return true; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h new file mode 100644 index 0000000000000..11b98db8554e4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#pragma once

// NOTE(review): the header name of this #include was lost in extraction; <vector> is
// assumed because std::vector is used in the declarations below - confirm.
#include <vector>

#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/nn/fuse_utils.h"

namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
//
// Uniforms: O, I, H, W are the four weight dimensions; Ci_tiles is the number of
// 64-wide tiles over I; H_W_tiles is the number of 4-wide tiles over H * W.
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

// Fused im2col + matmul convolution program.
//
//   has_bias:     whether a bias input is bound and added to the result.
//   tile_m/tile_n: tile extents along the M and N axes of the conceptual matmul.
//   use_subgroup: selects the subgroup-shuffle variant of the inner loop.
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;

  uint32_t tile_m_;
  uint32_t tile_n_;
  bool use_subgroup_;
};

// Returns true when ApplyIm2ColMatMulProgram can be used for this convolution.
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

// Runs the convolution through the im2col-matmul path and writes the result to `output`.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);

}  // namespace webgpu
}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template
new file mode 100644
index 0000000000000..54fb94d42eea5
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template
@@ -0,0 +1,143 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param has_bias
#param tile_m
#param tile_n
#param use_subgroup

#use .getByOffset .setByOffset

// im2col access for src: [N, H_i, W_i, C_i / 4] (vec4-packed NHWC)
// Conceptual Matrix Shape: N * (H_o * W_o) x (K_h * K_w * C_i / 4)
// Returns a zero value for any index outside the matrix or inside the padding region.
fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t {
  if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * 4 >= uniforms.im2col_k) {
    return src_value_t();
  }

  let channel_i_v4 = uniforms.channel_i / 4;

  // 1. Decompose M index (H_o * W_o) into (h_idx, w_idx)
  let h_idx = m / uniforms.output_w;  // Output H index (H_o)
  let w_idx = m % uniforms.output_w;  // Output W index (W_o)

  // 2. Decompose K index into (k_h, k_w, c_i_v4_idx)
  let c_i_v4_idx = k_packed_idx % channel_i_v4;
  let k_h_w_idx = k_packed_idx / channel_i_v4;
  let k_h = k_h_w_idx / uniforms.kernel_w;  // Kernel Row
  let k_w = k_h_w_idx % uniforms.kernel_w;  // Kernel Column

  // 3. Calculate the coordinate in the padded input tensor
  let src_h_coord_padded = h_idx * uniforms.strides.x + k_h * uniforms.dilations.x;
  let src_w_coord_padded = w_idx * uniforms.strides.y + k_w * uniforms.dilations.y;

  // 4. Calculate the coordinate in the original input tensor
  // pads follows the ONNX Conv layout [h_begin, w_begin, h_end, w_end], so the begin
  // pads are .x (height) and .y (width); .z is the end pad of the height axis.
  let src_h_coord : i32 = i32(src_h_coord_padded) - i32(uniforms.pads.x);
  let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.y);

  // 5. Check for padding/out-of-bounds
  if (src_h_coord < 0 || src_h_coord >= i32(uniforms.src_h) ||
      src_w_coord < 0 || src_w_coord >= i32(uniforms.src_w)) {
    return src_value_t();
  }

  // 6. Calculate final NHWC/vec4 index
  let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_v4 +
                u32(src_h_coord) * uniforms.src_w * channel_i_v4 +
                u32(src_w_coord) * channel_i_v4 +
                c_i_v4_idx;
  return src.getByOffset(src_idx);
}

// weight shape: [Co, K_h, K_w, C_i / 4] (vec4-packed CoHWCi)
// Returns a zero value when (n, k_packed_idx) is out of range.
fn load_weight(n : u32, k_packed_idx : u32) -> weight_value_t {
  if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / 4) {
    let weight_idx = n * uniforms.im2col_k / 4 +
                     k_packed_idx;
    return weight.getByOffset(weight_idx);
  }
  return weight_value_t();
}

// Returns the bias for output channel n, or zero when n is out of range or no bias is bound.
#if has_bias
fn load_bias(n : u32) -> bias_element_t {
  if (n < uniforms.im2col_n) {
    return bias[n];
  }
  return bias_element_t();
}
#else
fn load_bias(n : u32) -> output_element_t {
  return output_element_t();
}
#endif

// output shape: [N, H_o, W_o, C_o] (NHWC)
// Writes are dropped when any index is out of range.
fn write_output(batch : u32, m : u32, n : u32, value : output_element_t) {
  if (batch < uniforms.batch && m < uniforms.im2col_m && n < uniforms.im2col_n) {
    let output_idx = batch * uniforms.im2col_m * uniforms.im2col_n +
                     m * uniforms.im2col_n +
                     n;
    output.setByOffset(output_idx, value);
  }
}

const TILE_M_SIZE : u32 = tile_m;
const TILE_N_SIZE : u32 = tile_n;
const TILE_K_VEC_SIZE : u32 = 4;

// Workgroup-shared staging tiles, indexed [k within tile][m or n within tile].
var<workgroup> src_tile : array<array<src_value_t, TILE_M_SIZE>, TILE_K_VEC_SIZE>;
var<workgroup> weight_tile : array<array<weight_value_t, TILE_N_SIZE>, TILE_K_VEC_SIZE>;

$MAIN {
  // workgroup_idx is decomposed as (batch, m tile, n tile), n fastest.
  let batch = workgroup_idx / (uniforms.M_tiles * uniforms.N_tiles);
  let m_global_base = ((workgroup_idx / uniforms.N_tiles) % uniforms.M_tiles) * TILE_M_SIZE;
  let n_global_base = (workgroup_idx % uniforms.N_tiles) * TILE_N_SIZE;

  // Each invocation accumulates one output column (n) for all TILE_M_SIZE rows.
  var results : array<output_element_t, TILE_M_SIZE>;
  for (var k_idx = 0u; k_idx < uniforms.K_tiles; k_idx++) {
    for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 16u) {
      // Loads a 16x4 vec of src into the workgroup memory.
      let load_src_m = src_m + local_idx / 4;
      let load_src_k = local_idx % 4;

      src_tile[load_src_k][load_src_m] = load_src(batch,
                                                  m_global_base + load_src_m,
                                                  k_idx * TILE_K_VEC_SIZE + load_src_k);
    }

    for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 16u) {
      // Loads a 16x4 vec of weight into the workgroup memory.
      let load_weight_n = weight_n + local_idx / 4;
      let load_weight_k = local_idx % 4;

      weight_tile[load_weight_k][load_weight_n] = load_weight(n_global_base + load_weight_n,
                                                              k_idx * TILE_K_VEC_SIZE + load_weight_k);
    }
    workgroupBarrier();

    for (var inner_k_idx = 0u; inner_k_idx < TILE_K_VEC_SIZE; inner_k_idx++) {
      let weight_data = weight_tile[inner_k_idx][local_idx];
#if use_subgroup
      // Each lane holds the src row matching its subgroup lane id; rows are then
      // broadcast lane by lane with subgroupShuffle.
      // NOTE(review): this reads lanes [0, TILE_M_SIZE), so it relies on the
      // subgroup size being >= TILE_M_SIZE - confirm on the host side.
      let src_data = src_tile[inner_k_idx][sg_id];
      for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) {
        results[m_idx] += output_element_t(dot(weight_data, subgroupShuffle(src_data, m_idx)));
      }
#else
      for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) {
        results[m_idx] += output_element_t(dot(weight_data, src_tile[inner_k_idx][m_idx]));
      }
#endif
    }
    // Keep the tiles intact until every invocation has finished reading them.
    workgroupBarrier();
  }

  let m_base = m_global_base;
  let n_base = n_global_base + local_idx;

  let bias = load_bias(n_base);
  for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) {
    var output_data = results[m_idx] + output_element_t(bias);
    write_output(batch, m_base + m_idx, n_base, output_data);
  }
}  // MAIN
diff --git a/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template b/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template
new file mode 100644
index 0000000000000..dfbe7dde28b53
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#use .getByOffset .setByOffset

// Reads src[co][ci][h_w] from the OIHW weight, with H and W flattened into h_w.
// Returns a zero value when any index is out of range.
fn load_src(co : u32, ci : u32, h_w : u32) -> src_element_t {
  if (co < uniforms.O && ci < uniforms.I && h_w < uniforms.H * uniforms.W) {
    let offset = co * uniforms.I * uniforms.H * uniforms.W +
                 ci * uniforms.H * uniforms.W +
                 h_w;
    return src.getByOffset(offset);
  }
  return src_element_t();
}

// Writes output[co][h_w][ci] of the OHWI weight, with H and W flattened into h_w.
// Writes are dropped when any index is out of range.
fn write_output(co : u32, h_w : u32, ci : u32, value : output_element_t) {
  if (co < uniforms.O && ci < uniforms.I && h_w < uniforms.H * uniforms.W) {
    let offset = co * uniforms.H * uniforms.W * uniforms.I +
                 h_w * uniforms.I +
                 ci;
    output.setByOffset(offset, value);
  }
}

// Workgroup-shared staging tile, indexed [h_w within tile (4)][ci within tile (64)].
var<workgroup> data_cache : array<array<src_element_t, 64>, 4>;

$MAIN {
  // One workgroup handles one output channel and one tile of 64 input channels.
  let group_co : u32 = workgroup_idx / uniforms.Ci_tiles;
  let group_ci : u32 = (workgroup_idx % uniforms.Ci_tiles) * 64;

  // Both values depend only on workgroup_idx, so the whole workgroup exits together.
  if (group_co >= uniforms.O || group_ci >= uniforms.I) {
    return;
  }

  for (var h_w_idx = 0u; h_w_idx < uniforms.H_W_tiles; h_w_idx++) {
    // load
    // Each pass of this loop fills 16 input channels x 4 h_w positions.
    for (var ci_idx = 0u; ci_idx < 64u; ci_idx += 16u) {
      let load_ci_idx = ci_idx + local_idx / 4;
      let load_h_w_idx = local_idx % 4;

      data_cache[load_h_w_idx][load_ci_idx] = load_src(group_co,
                                                       group_ci + load_ci_idx,
                                                       h_w_idx * 4 + load_h_w_idx);
    }
    workgroupBarrier();

    // store
    // Each invocation owns one input channel and writes it for the 4 cached h_w positions.
    for (var local_h_w_idx = 0u; local_h_w_idx < 4u; local_h_w_idx++) {
      let output_data = data_cache[local_h_w_idx][local_idx];
      write_output(group_co, h_w_idx * 4 + local_h_w_idx, group_ci + local_idx, output_data);
    }
    // Do not overwrite the cache until every invocation has stored its values.
    workgroupBarrier();
  }
}  // MAIN