diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index bed379c0c35..b63f89e299d 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { VK_THROW("Could not get dtype of value with type ", val.type()); } +bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (!is_buffer_storage(idx)) { + return false; + } + return is_contiguous(idx); +} + +bool ComputeGraph::is_standard_channels_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 2; +} + +bool ComputeGraph::is_standard_width_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 0; +} + ValueRef ComputeGraph::add_tensor( const std::vector& sizes, const vkapi::ScalarType dtype, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 21d80d5843f..eac632e6d35 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -231,7 +231,7 @@ class ComputeGraph final { inline ptr_type get_##short_name(const ValueRef idx) { \ return ptr_type(this, idx); \ } \ - inline bool val_is_##short_name(const ValueRef idx) { \ + inline bool val_is_##short_name(const ValueRef idx) const { \ return values_.at(idx).is##type_name(); \ } @@ -314,6 +314,32 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().has_buffer_storage(); } + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has buffer storage + * 3. The buffer backed tensor at `idx` has a contiguous memory layout + */ + bool is_contiguous_buffer_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. The texture backed tensor at `idx` is channels packed + */ + bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. 
The texture backed tensor at `idx` is width packed + */ + bool is_standard_width_packed_texture_tensor(const ValueRef idx) const; + inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base) const { return values_.at(maybe_view) @@ -354,7 +380,7 @@ class ComputeGraph final { return values_.at(idx).toTensor().numel_ubo(); } - inline bool has_standard_axis_map(const ValueRef idx) { + inline bool has_standard_axis_map(const ValueRef idx) const { return values_.at(idx).toTensor().has_standard_axis_map(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl new file mode 100644 index 00000000000..70fdf2bae17 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +#define BUF_T ${buffer_scalar_type(DTYPE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "ivec4", "mean_strides")} +${layout_declare_ubo(B, "int", "mean_numel")} +${layout_declare_ubo(B, "ivec3", "in_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")} +const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout); + +#define LOCAL_WORK_GROUP_SIZE 64 +shared float shared_sum[LOCAL_WORK_GROUP_SIZE]; +shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE]; + +/* + * Computes the mean and standard deviation of one group of channels of the + * input tensor for the group normalization operator. + * + * Given a tensor of shape [W, H, C, N] the mean and standard deviation tensors + * will have a shape of [G, N] where G is the number of channel groups. + * + * The input tensor is assumed to be a channels-packed texture tensor with the + * standard axis mapping. The output tensors are assumed to be contiguous buffer + * tensors. + * + * Algorithm: + * 1. Each work group corresponds to one channel group in one batch + * 2. The local work group cooperatively reduces over all spatial locations (H×W) + * and all channels within the group (C/group channels) + * 3. Uses shared memory for efficient parallel reduction + * 4. Main thread (local ID 0) writes the final mean and rstd to buffer + * + * Global work group size: {N, 1, 1} + * N is the number of elements in the mean/rstd buffers; each work group computes + * one output element. + * + * Local work group size: {1, LOCAL_WORK_GROUP_SIZE, 1} + * LOCAL_WORK_GROUP_SIZE should be a power of 2, recommended 64 or 128 threads. This allows + * efficient tree-based reduction in shared memory. Each local group will + * cooperate to compute the output element. + * + * Each work group will compute the mean and standard deviation for one + * channel group in the input, and write out the corresponding result. 
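+ *
+ * Worked example (illustrative numbers only, chosen to make the indexing
+ * arithmetic concrete): suppose in_sizes = [W=4, H=4, C=8, N=2] and group = 2.
+ * Then D = C / group = 4 channels per group, HxW = 16, and group_size = 64
+ * elements are reduced per output. For the work group handling the second
+ * channel group of the first batch, group_start_channel = 4,
+ * group_end_channel = 8, start_texel_idx = 4 / 4 = 1, end_texel_idx =
+ * divup4(8) = 2, so texels_in_group = 1 and total_texels = 1 * 16 = 16 texels
+ * are split across the 64 local threads. The reduction then yields
+ * mean = sum / 64, variance = sum_sq / 64 - mean^2, and
+ * rstd = 1 / sqrt(variance + epsilon).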
+ */ +void group_norm_reduce_C_packed() { + const int global_idx = int(gl_GlobalInvocationID.x); + const int local_idx = int(gl_LocalInvocationID.y); + + // Calculate group dimensions + const int D = in_sizes.z / group; // channels per group + const int HxW = in_sizes.y * in_sizes.x; // spatial size + const int group_size = D * HxW; // total elements per group + + // Convert global index to (group_idx, batch_idx) + const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order); + + // Initialize local sums + float local_sum = 0.0; + float local_sum_sq = 0.0; + int local_count = 0; + + // Calculate the range of channels for this group + const int group_start_channel = mean_tidx.x * D; + const int group_end_channel = group_start_channel + D; + + // Calculate the range of texels that contain channels from this group + const int start_texel_idx = group_start_channel / 4; + const int end_texel_idx = divup4(group_end_channel); + const int texels_in_group = end_texel_idx - start_texel_idx; + + // Total texels to process across all spatial locations + const int total_texels = texels_in_group * HxW; + + // Each thread processes a subset of texels + const int texels_per_thread = (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE; + const int start_texel = local_idx * texels_per_thread; + const int end_texel = min(start_texel + texels_per_thread, total_texels); + + // Process assigned texels + for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) { + // Convert texel index to spatial and channel coordinates + const int spatial_idx = texel_idx / texels_in_group; + const int texel_in_group = texel_idx % texels_in_group; + + // Convert to spatial coordinates + const int w = spatial_idx % in_sizes.x; + const int h = spatial_idx / in_sizes.x; + + // Calculate the global texel index + const int global_texel_idx = start_texel_idx + texel_in_group; + + // Convert to texture position using default axis mapping + ivec3 tex_pos = ivec3(w, h, global_texel_idx); + + // Adjust for batch dimension if needed + if (in_sizes.w > 1) { + // default axis mapping means channels is the batch concat dim + tex_pos.z += mean_tidx.y * divup4(in_sizes.z); + } + + // Check bounds and load texel + if (all(lessThan(tex_pos, in_limits))) { + const vec4 texel_val = load_texel(t_in, tex_pos); + + // Process all components of the texel that belong to this group + const int texel_start_channel = global_texel_idx * 4; + for (int comp = 0; comp < 4; comp++) { + const int current_channel = texel_start_channel + comp; + + // Check if this component belongs to the current group + if (current_channel >= group_start_channel && current_channel < group_end_channel) { + const float val = texel_val[comp]; + local_sum += val; + local_sum_sq += val * val; + local_count++; + } + } + } + } + + // Store local results in shared memory + shared_sum[local_idx] = local_sum; + shared_sum_sq[local_idx] = local_sum_sq; + + // Synchronize threads + memoryBarrierShared(); + barrier(); + + // Perform tree-based reduction in shared memory + for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) { + if (local_idx < stride) { + shared_sum[local_idx] += shared_sum[local_idx + stride]; + shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride]; + } + memoryBarrierShared(); + barrier(); + } + + // Main thread writes the result + if (local_idx == 0 && global_idx < mean_numel) { + const float total_sum = shared_sum[0]; + const float total_sum_sq = shared_sum_sq[0]; + const float count = 
float(group_size); + + // Calculate mean and reciprocal standard deviation + const float mean_val = total_sum / count; + const float variance = (total_sum_sq / count) - (mean_val * mean_val); + const float rstd_val = 1.0 / sqrt(variance + epsilon); + + // Write to buffer-backed tensors + t_mean[global_idx] = BUF_T(mean_val); + t_rstd[global_idx] = BUF_T(rstd_val); + } +} + +void main() { + group_norm_reduce_C_packed(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml new file mode 100644 index 00000000000..00c357a1d6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +group_norm_reduce_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_reduce_texture diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl new file mode 100644 index 00000000000..8440481963a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl @@ -0,0 +1,129 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_rstd", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec3", "weight_limits")} +${layout_declare_ubo(B, "ivec4", "mean_strides")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Applies group normalization to t_in and writes the results to t_out. The mean + * and rstd of the input tensor are precomputed and passed in as t_mean and + * t_rstd. + * + * Given an input tensor t_in of shape [N, C, H, W], the mean and rstd will have + * shape [N, group], where group is the number of channel groups, and the output + * will have the same shape as t_in. The weight and bias tensors will have a + * shape of [C]. + * + * In this implementation, the input and output tensors are assumed to be + * channels-packed textures with standard axis mapping. + * + * The weight and bias tensors are assumed to be width-packed textures with + * standard axis mapping. + * + * The mean and rstd tensors are assumed to be contiguous buffer-backed tensors. 
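+ *
+ * Worked example (illustrative numbers only): with out_sizes = [W, H, C=6, N=2]
+ * and group = 3, channels_per_group = 2 and C_aligned = alignup4(6) = 8. A
+ * thread at pos.z = 3 starts from out_tidx.z = 4 * 3 = 12; since out_sizes.w > 1,
+ * batch_idx = 12 / 8 = 1 and out_tidx.z becomes 12 % 8 = 4, so the texel holds
+ * channels 4 and 5 (channels 6 and 7 are padding and are passed through
+ * unchanged). The weight/bias texel is read from x = 4 / 4 = 1, and channel 4
+ * falls in group_idx = 4 / 2 = 2, so mean and rstd are read at the buffer index
+ * computed from tensor index (2, 1, 0, 0) via mean_strides.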
+ */ +void apply_group_norm() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Check bounds + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + // Convert texture position to tensor coordinates using default axis mapping + // and channels packing + ivec4 out_tidx = ivec4(pos.x, pos.y, mul4(pos.z), 0); + + // Handle batch dimension if batches > 1 + if (out_sizes.w > 1) { + const int C_aligned = alignup4(out_sizes.z); + // default axis mapping means channels is the batch concatenation dim + const int batch_idx = out_tidx.z / C_aligned; + out_tidx.w = batch_idx; + out_tidx.z = out_tidx.z % C_aligned; + } + + // Load input texel (contains 4 consecutive channels) + const vec4 input_texel = load_texel(t_in, pos); + + // Load weight and bias texels, which are width-packed; each element along the + // width dim corresponds to a channel in the input tensor. + const ivec3 weight_pos = ivec3(out_tidx.z / 4, 0, 0); + const vec4 weight_texel = load_texel(t_weight, weight_pos); + const vec4 bias_texel = load_texel(t_bias, weight_pos); + + // Calculate which channels this texel represents + // For channels-packed layout: texel at position z contains channels [z, z+1, z+2, z+3] + const int base_channel = out_tidx.z; + + // Calculate buffer indices for mean/rstd lookup + // Mean/rstd tensors have shape [G, N] where G = C/group + const int batch_idx = out_tidx.w; + const int channels_per_group = out_sizes.z / group; + + vec4 bias; + // Process each element of the output texel individually, since each element + // may belong to a different channel group + for (int i = 0; i < 4; ++i) { + const int channel_idx = base_channel + i; + // Handle case where padding channels are added + if (channel_idx >= out_sizes.z) { + bias[i] = input_texel[i]; + continue; + } + + // Calculate group index for this channel + const int group_idx = channel_idx / channels_per_group; + + // Create tensor index for mean/rstd buffer access + const ivec4 mean_tidx = ivec4(group_idx, batch_idx, 0, 0); + const int mean_bufi = tidx_to_bufi(mean_tidx, mean_strides); + + // Load mean and rstd values for this channel + const float mean_val = t_mean[mean_bufi]; + const float rstd_val = t_rstd[mean_bufi]; + + // Apply group normalization with weight and bias: ((input - mean) * rstd) * weight + bias + const float normalized = (input_texel[i] - mean_val) * rstd_val; + bias[i] = normalized * weight_texel[i] + bias_texel[i]; + } + + // Write result to output texture + write_texel(t_out, pos, bias); +} + +void main() { + apply_group_norm(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml new file mode 100644 index 00000000000..b50853be3b0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +group_norm_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp new file mode 100644 index 00000000000..8d2a848b0c4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +std::vector<int64_t> calc_group_norm_mean_sizes( + api::vTensor& self, + const int64_t group) { + const std::vector<int64_t>& input_sizes = self.sizes(); + const int64_t N = input_sizes.at(0); + return {N, group}; +} + +utils::uvec3 group_norm_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector<ArgGroup>& args, + const std::vector<ValueRef>& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + return {1, 64, 1}; +} + +void resize_group_norm_texture_node( + ComputeGraph* graph, + const std::vector<ArgGroup>& args, + const std::vector<ValueRef>& resize_args) { + // Extract tensor references from args + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const ValueRef mean = args.at(1).refs.at(3); + const ValueRef rstd = args.at(1).refs.at(4); + + // Extract group from resize args + const int64_t group_val = graph->extract_scalar<int64_t>(resize_args.at(0)); + + // Get input tensor sizes using ComputeGraph APIs + const std::vector<int64_t> in_sizes = graph->sizes_of(in); + + // Output tensor should have the same size as input + graph->virtual_resize(out, in_sizes); + + // Mean and rstd tensors should have size {num_batches, num_groups} + const int64_t N = in_sizes.at(0); // batch dimension + const std::vector<int64_t> mean_rstd_sizes = {N, group_val}; + + // Resize mean and rstd tensors + graph->virtual_resize(mean, mean_rstd_sizes); + graph->virtual_resize(rstd, mean_rstd_sizes); +} + +void add_native_group_norm_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias_data, + const ValueRef N, + const ValueRef C, + const ValueRef HxW, + const ValueRef group, + const ValueRef eps, + const ValueRef out, + const ValueRef mean, + const ValueRef rstd) { + (void)N; + (void)C; + (void)HxW; + + const ValueRef arg_weight = prepack_standard( + graph, + weight_data, + graph.storage_type_of(in), + utils::kWidthPacked, + false); + const ValueRef arg_bias = prepack_standard( + graph, bias_data, graph.storage_type_of(in), utils::kWidthPacked, false); + + const int64_t group_val = graph.extract_scalar<int64_t>(group); + const float epsilon = graph.extract_scalar<float>(eps); + + const std::vector<int64_t> in_sizes = graph.sizes_of(in); + + std::string kernel_name("group_norm_reduce_texture"); + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const struct { + int32_t group; + float epsilon; + } params_uniform = {static_cast<int32_t>(group_val), epsilon}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + group_norm_local_wg_size, + // Inputs and Outputs + {{{mean, rstd}, vkapi::kWrite}, {in, 
vkapi::kRead}}, + // Shader params buffers + { + graph.strides_ubo(mean), + graph.numel_ubo(mean), + graph.logical_limits_ubo(in), + graph.sizes_ubo(in), + }, + // Push Constants + { + PushConstantDataInfo(&params_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(mean), + }, + // Resize Args + {group}, + // Resizing Logic + nullptr)); + + // Compute element-wise normalization, now that mean and rstd have been + // computed. + std::string norm_kernel_name("group_norm_texture"); + norm_kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(norm_kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(norm_kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, + {{in, arg_weight, arg_bias, mean, rstd}, vkapi::kRead}}, + // Shader params buffers + { + graph.logical_limits_ubo(out), + graph.sizes_ubo(out), + graph.logical_limits_ubo(arg_weight), + graph.strides_ubo(mean), + }, + // Push Constants + { + PushConstantDataInfo(&params_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(in), + }, + // Resize Args + {group}, + // Resizing Logic + resize_group_norm_texture_node)); +} + +void native_group_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) { + // Assign each element of the args vector to const ValueRef variables + const ValueRef in = args.at(0); + const ValueRef weight_data = args.at(1); + const ValueRef bias_data = args.at(2); + const ValueRef N = args.at(3); + const ValueRef C = args.at(4); + const ValueRef HxW = args.at(5); + const ValueRef group = args.at(6); + const ValueRef eps = args.at(7); + const ValueRef out_tuple_ref = args.at(8); + + ValueRef out = kDummyValueRef; + ValueRef mean = kDummyValueRef; + ValueRef rstd = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + out = out_tuple->at(0); + mean = out_tuple->at(1); + rstd = out_tuple->at(2); + } + + VK_CHECK_COND(graph.val_is_tref(weight_data)); + VK_CHECK_COND(graph.val_is_tref(bias_data)); + + // Check expected storage types and memory layouts for tensor variables + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(in)); + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(out)); + + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(mean)); + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(rstd)); + + return add_native_group_norm_node( + graph, + in, + weight_data, + bias_data, + N, + C, + HxW, + group, + eps, + out, + mean, + rstd); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.native_group_norm.default, native_group_norm); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 92f73268ebf..0fd5ef4f002 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -646,6 +646,45 @@ def get_native_layer_norm_inputs(): return test_suite +@register_test_suite("aten.native_group_norm.default") +def get_native_group_norm_inputs(): + test_suite = VkTestSuite( + [ + # (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps) + # General test cases + ((1, 8, 4, 4), (8), (8), 1, 8, 16, 2, 0.001), + ((2, 8, 3, 3), (8), (8), 2, 8, 9, 4, 0.001), + ((1, 12, 2, 2), (12), (12), 1, 12, 4, 3, 0.001), + ((3, 16, 5, 5), (16), (16), 3, 16, 25, 8, 0.001), + ((3, 16, 13, 17), (16), (16), 3, 16, 13 * 17, 4, 0.001), + ((1, 4, 7, 7), 
(4), (4), 1, 4, 49, 2, 0.001), + ((2, 6, 1, 8), (6), (6), 2, 6, 8, 3, 0.001), + # Single group and prime number sizes + ((3, 7, 13, 11), (7), (7), 3, 7, 13 * 11, 1, 0.001), + # Each channel is its own group and prime number sizes + ((1, 7, 13, 11), (7), (7), 1, 7, 13 * 11, 7, 0.001), + ] + ) + test_suite.layouts = [ + "utils::kChannelsPacked", + ] + test_suite.storage_types = [ + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + "at::kHalf", + ] + test_suite.arg_storage_types = { + "out": [None, "utils::kBuffer", "utils::kBuffer"], + } + + test_suite.prepacked_args = ["weight", "bias"] + test_suite.requires_prepack = True + + return test_suite + + def get_upsample_inputs(): inputs_list = [ # (input tensor shape, output 2D image size (H, W), output scaling factors)