diff --git a/backends/vulkan/runtime/graph/ops/glsl/var_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/var_buffer.glsl new file mode 100644 index 00000000000..30f283d6f01 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/var_buffer.glsl @@ -0,0 +1,121 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define T ${buffer_scalar_type(DTYPE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "out_buf", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "in_buf", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "in_strides")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} + +layout(push_constant) uniform PushConstants { + int unbiased; +} pc; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int reduce_dim = 0; + +#define NWORKERS 4 +#define MAX_THREADS 16 + +shared T shared_sum[NWORKERS]; +shared T shared_sum_sq[NWORKERS]; +shared int shared_count[NWORKERS]; + +#include "indexing_utils.h" + +void main() { + const ivec4 out_idx = ivec4( + gl_GlobalInvocationID.x, + gl_GlobalInvocationID.y, + gl_GlobalInvocationID.z % out_sizes.z, + gl_GlobalInvocationID.z / out_sizes.z); + + const uint tid = gl_LocalInvocationID[reduce_dim]; + + shared_sum[tid] = T(0); + shared_sum_sq[tid] = T(0); + shared_count[tid] = 0; + barrier(); + + const int R = in_sizes[reduce_dim]; + const uint N = gl_WorkGroupSize[reduce_dim]; + + // Each workgroup processes a contiguous chunk of the input tensor + // along the reduce_dim. Each thread processes a part of this chunk. + uint q = R / N; + uint rem = R % N; + + uint len = q + (tid < rem ? 1u : 0u); + uint base = tid * q + min(tid, rem); + + T sum = T(0); + T sum_sq = T(0); + int count = 0; + + ivec4 in_idx = out_idx; + for (uint off = 0u; off < len; ++off) { + uint i = base + off; + in_idx[reduce_dim] = int(i); + + // out_idx is a 4D index, so for tensors with reduce_dim == 2, + // we need to set the reduce_dim + 1 to 0 as gl_GlobalInvocationID.z + // is influenced by the tid + if (reduce_dim == 2) { + in_idx[reduce_dim + 1] -= int(tid); + } + + T v = in_buf[tidx_to_bufi(in_idx, in_strides)]; + + sum += v; + sum_sq += v * v; + count += 1; + } + + shared_sum[tid] = sum; + shared_sum_sq[tid] = sum_sq; + shared_count[tid] = count; + barrier(); + + if (tid == 0u) { + T tot_sum = T(0); + T tot_sum_sq = T(0); + int tot_count = 0; + + for (uint i = 0; i < N; ++i) { + tot_sum += shared_sum[i]; + tot_sum_sq += shared_sum_sq[i]; + tot_count += shared_count[i]; + } + + T var; + if (tot_count > 0) { + T mean = tot_sum / T(tot_count); + var = (tot_sum_sq / T(tot_count)) - (mean * mean); + if (pc.unbiased != 0 && tot_count > 1) { + var *= T(tot_count) / T(tot_count - 1); + } + } else{ + // NaN to match PyTorch behavior + var = T(0.0/0.0); + } + + out_buf[tidx_to_bufi(out_idx, out_strides)] = var; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/var_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/var_buffer.yaml new file mode 100644 index 00000000000..7cb783775c9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/var_buffer.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +var_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: var_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.glsl new file mode 100644 index 00000000000..faeac01fcd2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.glsl @@ -0,0 +1,222 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "ivec3", "tin_limits")} +${layout_declare_ubo(B, "ivec4", "tin_sizes")} + +layout(push_constant) uniform PushConstants { + int unbiased; +} pc; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = 0; +layout(constant_id = 4) const int reduce_dim = 0; +layout(constant_id = 5) const int group_dim = 1; + +// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of +// threads that will co-operate to compute one reduction output. There may be +// multiple groups computing distinct reduction outputs within one work group. +#define NWORKERS 4 + +// Sets an upper limit on the total size of a work group based on how many +// elements are allocated in the shared memory array below. Each thread in the +// work group will write into its assigned element in the shared array. 
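+// With NWORKERS (4) threads co-operating on each reduction and up to
+// MAX_NTHREADS / NWORKERS (4) reduction groups packed into one work group,
+// 16 slots are sufficient; a thread's slot is computed by tid_to_smi() below
+// as tid.x + tid.y * NWORKERS.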
+#define MAX_NTHREADS 16 + +shared VEC4_T shared_sum[MAX_NTHREADS]; +shared VEC4_T shared_sum_sq[MAX_NTHREADS]; +shared int shared_count[MAX_NTHREADS]; + +#include "indexing_utils.h" + +int tid_to_smi(const ivec2 tid) { + return tid.x + tid.y * NWORKERS; +} + +VEC4_T calculate_variance(VEC4_T sum, VEC4_T sum_sq, int count) { + VEC4_T mean = sum / float(count); + VEC4_T variance = (sum_sq / float(count)) - (mean * mean); + + if ((pc.unbiased != 0) && (count > 1)) { + variance = variance * (float(count) / float(count - 1.0)); + } + + return variance; +} + +void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { + // shared memory index of this thread + const int smi = tid_to_smi(tid); + + VEC4_T sum = VEC4_T(0); + VEC4_T sum_sq = VEC4_T(0); + int count = 0; + + scan_pos[reduce_dim] = tid.x; + for (int i = tid.x; i < tin_sizes[reduce_dim]; + i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { + VEC4_T val = load_texel(tin, scan_pos); + sum += val; + sum_sq += val * val; + count += 1; + } + // Write partial output to shared memory and synchronize work group + shared_sum[smi] = sum; + shared_sum_sq[smi] = sum_sq; + shared_count[smi] = count; + barrier(); + + // Since the reduction row is reduced to only one element, only the "main" + // thread in the group needs aggregate the partial outputs + if (tid.x == 0) { + int group_i = tid.y * NWORKERS; + sum = shared_sum[group_i]; + sum_sq = shared_sum_sq[group_i]; + count = shared_count[group_i]; + + for (int i = 1; i < NWORKERS; i++) { + int idx = tid.y * NWORKERS + i; + sum += shared_sum[idx]; + sum_sq += shared_sum_sq[idx]; + count += shared_count[idx]; + } + + // Determine if there are any padding elements in the final texel of the + // packed dimension + const int nspill = mod4(tin_sizes[packed_dim]); + // Detect if this thread is working on the final texels of the packed + // dimension, which may have padding elements + const bool is_last_texel = + scan_pos[packed_dim] == (tin_limits[packed_dim] - 1); + + VEC4_T variance = calculate_variance(sum, sum_sq, count); + + // Explicitly set padding elements to 0 + if (is_last_texel && nspill > 0) { + [[unroll]] for (int i = nspill; i < 4; i++) { + variance[i] = 0; + } + } + + scan_pos[reduce_dim] = tid.x; + write_texel(tout, scan_pos, variance); + } +} + +/* + * Compute reduction where the reduction dim is also the packed dim. This case is + * complex because the reduction needs to occur over the individual texels. + * Therefore, in this algorithm each element of the accumulator texels are + * themselves partial outputs. Special care has to be taken to ignore padding + * elements in texels (which occur when the size of the packed dim is not a + * multiple of 4) so that they do not influence the output of reduction. + */ +void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) { + // shared memory index of this thread + const int smi = tid_to_smi(tid); + + // Number of non-padding elements in the last texel in the reduction row + const int nspill = mod4(tin_sizes[packed_dim]); + // Only reduce up to the last "complete" texel. The last texel will need to be + // handled specially if it has padding elements. + const int reduce_len = tin_sizes[packed_dim] - nspill; + + VEC4_T sum = VEC4_T(0); + VEC4_T sum_sq = VEC4_T(0); + int count = 0; + + // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... 
of + // the reduction row + scan_pos[reduce_dim] = tid.x; + for (int i = tid.x * 4; i < reduce_len; + i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { + VEC4_T val = load_texel(tin, scan_pos); + sum += val; + sum_sq += val * val; + count += 4; + } + // For the last texel in the dim, if there are padding elements then each + // element of the texel needs to be processed individually such that the + // padding elements are ignored + if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) { + const VEC4_T val = load_texel(tin, scan_pos); + for (int i = 0; i < nspill; i++) { + sum.x += val[i]; + sum_sq.x += val[i] * val[i]; + count += 1; + } + } + // Write partial output to shared memory and synchronize work group + shared_sum[smi] = sum; + shared_sum_sq[smi] = sum_sq; + shared_count[smi] = count; + barrier(); + + // Since the reduction row is reduced to only one element, only the "main" + // thread in the group needs aggregate the partial outputs + if (tid.x == 0) { + sum = shared_sum[tid.y * NWORKERS]; + sum_sq = shared_sum_sq[tid.y * NWORKERS]; + count = shared_count[tid.y * NWORKERS]; + for (int i = 1; i < NWORKERS; i++) { + int idx = tid.y * NWORKERS + i; + sum += shared_sum[idx]; + sum_sq += shared_sum_sq[idx]; + count += shared_count[idx]; + } + + // Combine across the elements of the combined state + float total_sum = sum.x + sum.y + sum.z + sum.w; + float total_sum_sq = sum_sq.x + sum_sq.y + sum_sq.z + sum_sq.w; + int total_count = count; + + float mean = total_sum / float(total_count); + float variance = (total_sum_sq / float(total_count)) - (mean * mean); + + if ((pc.unbiased != 0) && (total_count > 1)) { + variance = variance * (float(total_count) / float(total_count - 1.0)); + } + + scan_pos[reduce_dim] = tid.x; + write_texel(tout, scan_pos, VEC4_T(variance, 0, 0, 0)); + } +} + +void main() { + ivec3 scan_pos = ivec3(gl_GlobalInvocationID); + scan_pos[reduce_dim] = 0; + + const ivec2 tid = ivec2( + gl_LocalInvocationID[reduce_dim], + gl_LocalInvocationID[group_dim]); + + if (any(greaterThanEqual(scan_pos, tin_limits))) { + return; + } + + if (reduce_dim != packed_dim) { + reduce_nonpacked_dim(tid, scan_pos); + } else { + reduce_packed_dim(tid, scan_pos); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.yaml new file mode 100644 index 00000000000..9cecbedca1a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/var_texture3d.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +var_texture3d: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: var_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Var.cpp b/backends/vulkan/runtime/graph/ops/impl/Var.cpp new file mode 100644 index 00000000000..41fdc41e982 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Var.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include +#include + +namespace vkcompute { + +using namespace utils; + +void resize_var_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + int dim = extra_args[0]; + + std::vector new_sizes = in->sizes(); + if (!new_sizes.empty()) { + new_sizes.at(normalize(dim, new_sizes.size())) = 1; + } + out->virtual_resize(new_sizes); +} + +void add_var_buffer_node( + ComputeGraph& graph, + ValueRef in, + const int dim, + bool unbiased, + ValueRef out) { + const int64_t ndim = graph.dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + // Check that the concat dim is not the reduction dim, if the tensor has a + // batch dim greater than 1 + if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) { + VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim); + VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim); + } + + std::string kernel_name = "var"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const uint32_t nworkers_per_group = 4; + + utils::uvec3 global_wg_size = { + graph.size_at(-1, out), + graph.size_at(-2, out), + graph.size_at(-3, out) * graph.size_at(-4, out)}; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + + std::vector push_constants; + int32_t unbiased_int = static_cast(unbiased); + push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int)); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + { + graph.sizes_ubo(in), + graph.strides_ubo(in), + graph.sizes_ubo(out), + graph.strides_ubo(out), + }, + // Push Constants + push_constants, + // Specialization Constants + {reduce_dim}, + // Resize Args + {dim}, + // Resizing Logic + resize_var_node)); +} + +void add_var_texture_node( + ComputeGraph& graph, + ValueRef in, + const int dim, + bool unbiased, + ValueRef out) { + const int64_t ndim = graph.dim_of(in); + + int32_t reduce_dim = dim; + reduce_dim = normalize(reduce_dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + // Check that the concat dim is not the reduction dim, if the tensor has a + // batch dim greater than 1. + if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) { + VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim); + VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim); + } + + std::string kernel_name = "var"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + // This should match the value of MAX_NTHREADS in the softmax shader. 
+  constexpr uint32_t max_nthreads = 16;
+
+  const uint32_t nworkers_per_group = 4;
+  const uint32_t ngroups = 4;
+  VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads);
+
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[reduce_dim] = 1;
+
+  utils::uvec3 local_wg_size{1, 1, 1};
+  local_wg_size[reduce_dim] = nworkers_per_group;
+  const int other_dim_1 = (reduce_dim + 1) % 3;
+  const int other_dim_2 = (reduce_dim + 2) % 3;
+  int32_t group_dim;
+  if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) {
+    local_wg_size[other_dim_1] = ngroups;
+    group_dim = other_dim_1;
+  } else {
+    local_wg_size[other_dim_2] = ngroups;
+    group_dim = other_dim_2;
+  }
+
+  std::vector<PushConstantDataInfo> push_constants;
+  int32_t unbiased_int = static_cast<int32_t>(unbiased);
+  push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int));
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Shader params buffers
+      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {graph.packed_dim_of(out), reduce_dim, group_dim},
+      // Resize Args
+      {dim},
+      // Resizing Logic
+      resize_var_node));
+}
+
+void add_var_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    const int dim,
+    bool unbiased,
+    ValueRef out) {
+  bool is_buffer = graph.is_buffer_storage(in) || graph.is_buffer_storage(out);
+
+  if (is_buffer) {
+    add_var_buffer_node(graph, in, dim, unbiased, out);
+  } else {
+    add_var_texture_node(graph, in, dim, unbiased, out);
+  }
+}
+
+void var(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  const IntListPtr dims_list = graph.get_int_list(args[1]);
+  VK_CHECK_COND(dims_list->size() == 1);
+  bool unbiased = true;
+  if (args.size() > 2) {
+    unbiased = graph.get_bool(args[2]);
+  }
+  return add_var_node(
+      graph, args[0], static_cast<int>(dims_list->at(0)), unbiased, args[4]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.var.dim, var);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 4692a84a4c9..feac3ab42ec 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1140,6 +1140,64 @@ def get_reduce_op_inputs():
     return test_suite
 
 
+@register_test_suite(["aten.var.dim"])
+def get_var_inputs():
+    test_cases = []
+    shapes_and_dims = [
+        ((L), 0),
+        ((L), -1),
+        ((M, L), 0),
+        ((M, L), 1),
+        ((L, M), -1),
+        ((M, L), -2),
+        ((S, S1, S2), 0),
+        ((S, S1, S2), 1),
+        ((S, S1, S2), 2),
+        ((S, S1, S2), -1),
+        ((S, S1, S2), -2),
+        ((S, S1, S2), -3),
+        ((1, S, S1, S2), 1),
+        ((1, S, S1, S2), 2),
+        ((1, S, S1, S2), 3),
+        ((1, S, S1, S2), -1),
+        ((1, S, S1, S2), -2),
+        ((1, S, S1, S2), -3),
+        # Test batches > 1 where the reduction dim is not the concat dim
+        ((S, L, S1, L), -1),
+        ((S, S2, S1, S), -2),
+        ((S, S2, M, M), 2),
+        ((S, M, S1, L), 3),
+    ]
+
+    for i, (shape, dim) in enumerate(shapes_and_dims):
+        unbiased = (i % 2) == 0
+        test_cases.append((shape, dim, unbiased, True))
+
+    # Texture-based tests
+    texture_test_suite = VkTestSuite(test_cases)
+    texture_test_suite.layouts = [
+        "utils::kChannelsPacked",
+        "utils::kWidthPacked",
+    ]
+    texture_test_suite.storage_types = ["utils::kTexture3D"]
+    texture_test_suite.atol = "1e-4"
+    texture_test_suite.rtol = "1e-4"
+    texture_test_suite.test_name_suffix = "texture"
+
+    # Buffer-based tests
+    buffer_test_suite =
VkTestSuite(test_cases) + buffer_test_suite.layouts = [ + "utils::kChannelsPacked", + "utils::kWidthPacked", + ] + buffer_test_suite.storage_types = ["utils::kBuffer"] + buffer_test_suite.atol = "1e-4" + buffer_test_suite.rtol = "1e-4" + buffer_test_suite.test_name_suffix = "buffer" + + return [texture_test_suite, buffer_test_suite] + + @register_test_suite( [ "aten.sqrt.default",