Commit a3f0b00
[ET-VK] Implement generic reduction shader + mean, sum, amax, amin
## Context

Introduce a generic shader that computes a reduction along a single dim with `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented.

Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/)

[ghstack-poisoned]
1 parent 7ce7526 commit a3f0b00
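For orientation, the generic reduction pattern the new shader template implements can be modeled in plain Python: an accumulator is initialized (either to a static value or to the first element of the row), updated with each element of the reduction row, and post-processed before being written out with the reduced dim kept at size 1. This is an illustrative sketch only; `reduce_single_dim` and the hook tuples are hypothetical names, not part of the commit.

```python
# Pure-Python model of the generic reduction pattern: INIT_ACCUM / UPDATE_ACCUM /
# POSTPROCESS specialize one kernel into sum, mean, amax, and amin.
# All names here are hypothetical and exist only for this sketch.

def reduce_single_dim(data, dim, init_fn, update_fn, post_fn):
    """Reduce a 2D list-of-lists along dim 1, keeping the reduced dim as size 1."""
    assert dim == 1, "sketch only handles dim == 1"
    out = []
    for row in data:
        accum = init_fn(row[0])          # 0 for sum/mean, first element for amax/amin
        for val in row:
            accum = update_fn(accum, val)
        out.append([post_fn(accum, len(row))])  # keepdim=True: reduced dim has size 1
    return out

# The four operators differ only in these three hooks:
sum_hooks  = (lambda first: 0,     lambda a, v: a + v,     lambda a, n: a)
mean_hooks = (lambda first: 0,     lambda a, v: a + v,     lambda a, n: a / n)
amax_hooks = (lambda first: first, lambda a, v: max(a, v), lambda a, n: a)
amin_hooks = (lambda first: first, lambda a, v: min(a, v), lambda a, n: a)

x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
print(reduce_single_dim(x, 1, *mean_hooks))  # [[2.5], [6.5]]
print(reduce_single_dim(x, 1, *amax_hooks))  # [[4.0], [8.0]]
```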

File tree

6 files changed, +411 −24 lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ def __contains__(self, op):
     # Reduction
     exir_ops.edge.aten._log_softmax.default,
     exir_ops.edge.aten._softmax.default,
+    exir_ops.edge.aten.mean.dim,
+    exir_ops.edge.aten.sum.dim_IntList,
+    exir_ops.edge.aten.amax.default,
+    exir_ops.edge.aten.amin.default,
     # 2D Pooling
     exir_ops.edge.aten.avg_pool2d.default,
     exir_ops.edge.aten.max_pool2d_with_indices.default,
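For reference, these edge ops are what single-dim, `keepdim=True` reductions typically lower to from eager PyTorch; the shape below is an arbitrary example.

```python
import torch

x = torch.randn(2, 8, 16, 16)

# Single-dim reductions with keepdim=True, the pattern the new Vulkan shader targets.
y_mean = torch.mean(x, dim=1, keepdim=True)    # aten.mean.dim
y_sum  = torch.sum(x, dim=[1], keepdim=True)   # aten.sum.dim_IntList (single-element dim list)
y_amax = torch.amax(x, dim=1, keepdim=True)    # aten.amax.default
y_amin = torch.amin(x, dim=1, keepdim=True)    # aten.amin.default

print(y_mean.shape)  # torch.Size([2, 1, 16, 16])
```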
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16

shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
  return tid.x + tid.y * NWORKERS;
}

/*
 * The functions below compute a reduction along a single dimension of a tensor.
 * The shader template generalizes reduction by abstracting the initial value of
 * the accumulator, the calculation used to update the accumulator with new
 * values, and a postprocessing calculation that can be used to modify the
 * accumulator before writing to output.
 *
 * This shader also utilizes shared memory so that multiple threads can help
 * compute each reduction output. A total of NGROUPS x NWORKERS threads are
 * expected to be launched. Each group works on a unique reduction "row", and
 * within a group NWORKERS threads co-operate to reduce one "row". Each worker
 * in the group is responsible for computing a partial output of the "row" and
 * uploading it to shared memory; the overall reduction output can then be
 * determined by aggregating the partial outputs stored in shared memory.
 *
 * As a caveat, this shader does not currently support cases where `batch` > 1
 * and the reduce dim happens to also be the batch concatenation dim. To support
 * this, there will need to be additional logic to set the starting value of
 * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
 * supporting this case is left as an exercise for when it is required.
 */

// Initializing the accumulator accepts the first value in the reduction row,
// since some reduction operations (i.e. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean which want to perform a final calculation
// with the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

/*
 * Computes the reduction where the reduction dim is orthogonal to the packed
 * dim. This case is simpler because each element of a texel belongs to a
 * separate reduction "group", meaning we don't have to perform reduction along
 * a texel.
 */
void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);

  scan_pos[reduce_dim] = 0;
  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

  scan_pos[reduce_dim] = tid.x;
  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
  // the reduction row
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
  }
  // Write partial output to shared memory and synchronize work group
  shared_vecs[smi] = accum;
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs to aggregate the partial outputs
  if (tid.x == 0) {
    // Iterate over the partial outputs to obtain the overall output
    int group_i = tid.y * NWORKERS;
    accum = shared_vecs[group_i++];
    for (int i = 1; i < NWORKERS; ++i, group_i++) {
      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
    }

    // Determine if there are any padding elements in the final texel of the
    // packed dimension
    const int nspill = mod4(tin_sizes[packed_dim]);
    // Detect if this thread is working on the final texels of the packed
    // dimension, which may have padding elements
    const bool is_last_texel =
        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

    // Explicitly set padding elements to 0
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int i = nspill; i < 4; ++i) {
        accum[i] = 0;
      }
    }
    scan_pos[reduce_dim] = tid.x;
    write_texel(tout, scan_pos, POSTPROCESS(accum));
  }
}

/*
 * Computes the reduction where the reduction dim is also the packed dim. This
 * case is more complex because the reduction needs to occur over the individual
 * texels. Therefore, in this algorithm each element of the accumulator texel is
 * itself a partial output. Special care has to be taken to ignore padding
 * elements in texels (which occur when the size of the packed dim is not a
 * multiple of 4) so that they do not influence the output of the reduction.
 */
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);

  // Number of non-padding elements in the last texel in the reduction row
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Only reduce up to the last "complete" texel. The last texel will need to be
  // handled specially if it has padding elements.
  const int reduce_len = tin_sizes[packed_dim] - nspill;

  scan_pos[reduce_dim] = 0;
  vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));

  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
  // the reduction row
  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
    const vec4 intex = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; ++i) {
      accum.x = UPDATE_ACCUM(accum.x, intex[i]);
    }
  }
  // Write partial output to shared memory and synchronize work group
  shared_vecs[smi] = accum;
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs to aggregate the partial outputs
  if (tid.x == 0) {
    // Iterate over the partial outputs to obtain the overall output
    int group_i = tid.y * NWORKERS;
    accum = shared_vecs[group_i++];
    for (int i = 1; i < NWORKERS; ++i, group_i++) {
      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
    }
    // Each element of the accumulator texel is itself a partial output; reduce
    // over the texel to obtain the final output
    float accum_final = accum.x;
    [[unroll]] for (int i = 1; i < 4; ++i) {
      accum_final = UPDATE_ACCUM(accum[i], accum_final);
    }

    scan_pos[reduce_dim] = tid.x;
    write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
  }
}

void main() {
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim] = 0;

  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim],
      gl_LocalInvocationID[group_dim]);

  if (any(greaterThanEqual(scan_pos, tin_limits))) {
    return;
  }

  if (reduce_dim != packed_dim) {
    reduce_nonpacked_dim(tid, scan_pos);
  } else {
    reduce_packed_dim(tid, scan_pos);
  }
}
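To make the cooperative scheme concrete, the sketch below models one reduction "row" in plain Python: each of NWORKERS workers accumulates a strided slice of the row into a "shared memory" slot, then the group's first worker aggregates the partial results, mirroring `reduce_nonpacked_dim`. The sum/max hooks shown match the template's UPDATE_ACCUM choices; `cooperative_reduce` is a hypothetical name used only for illustration.

```python
# Python model of the shader's cooperative reduction along one "row":
# each of NWORKERS workers accumulates elements i, i + NWORKERS, i + 2*NWORKERS, ...
# into a partial result ("shared memory"), then worker 0 aggregates the partials.

NWORKERS = 4

def cooperative_reduce(row, init, update):
    # Phase 1: each worker computes a partial accumulator over a strided slice.
    shared = [None] * NWORKERS
    for worker in range(NWORKERS):
        accum = init(row[0])
        for i in range(worker, len(row), NWORKERS):
            accum = update(accum, row[i])
        shared[worker] = accum          # analogous to shared_vecs[smi] = accum; barrier()

    # Phase 2: only the "main" worker (tid.x == 0) aggregates the partial outputs.
    accum = shared[0]
    for worker in range(1, NWORKERS):
        accum = update(accum, shared[worker])
    return accum

row = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
print(cooperative_reduce(row, lambda first: 0.0, lambda a, v: a + v))        # 31.0 (sum)
print(cooperative_reduce(row, lambda first: first, lambda a, v: max(a, v)))  # 9.0 (amax)
```

In the real shader each accumulator is a vec4 texel, and the padding lanes of the final texel along the packed dim are explicitly zeroed (or handled element-by-element in `reduce_packed_dim`) so they cannot influence the result before POSTPROCESS and write-out.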
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    INIT_ACCUM: VEC4_T(0)
    UPDATE_ACCUM: accum + new_val
    POSTPROCESS: accum
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: sum
    - NAME: mean
      POSTPROCESS: (accum / tin_sizes[reduce_dim])
    - NAME: amax
      INIT_ACCUM: first_val
      UPDATE_ACCUM: max(accum, new_val)
      POSTPROCESS: accum
    - NAME: amin
      INIT_ACCUM: first_val
      UPDATE_ACCUM: min(accum, new_val)
      POSTPROCESS: accum
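Each shader variant only overrides the three template hooks; everything else comes from the shared `reduce` template. As a rough model of that substitution (not the actual codegen, which is handled by the existing shader generation tooling; `render` and the dicts below are hypothetical), the resulting GLSL `#define` lines look like this:

```python
# Restates the YAML defaults and per-variant overrides, then shows the #define
# lines each variant would produce.

DEFAULTS = {
    "INIT_ACCUM": "VEC4_T(0)",
    "UPDATE_ACCUM": "accum + new_val",
    "POSTPROCESS": "accum",
}

VARIANTS = {
    "sum":  {},
    "mean": {"POSTPROCESS": "(accum / tin_sizes[reduce_dim])"},
    "amax": {"INIT_ACCUM": "first_val", "UPDATE_ACCUM": "max(accum, new_val)"},
    "amin": {"INIT_ACCUM": "first_val", "UPDATE_ACCUM": "min(accum, new_val)"},
}

SIGNATURES = {
    "INIT_ACCUM": "(first_val)",
    "UPDATE_ACCUM": "(accum, new_val)",
    "POSTPROCESS": "(accum)",
}

def render(name):
    params = {**DEFAULTS, **VARIANTS[name]}
    return "\n".join(
        f"#define {macro}{SIGNATURES[macro]} {params[macro]}" for macro in SIGNATURES
    )

print(render("mean"))
# #define INIT_ACCUM(first_val) VEC4_T(0)
# #define UPDATE_ACCUM(accum, new_val) accum + new_val
# #define POSTPROCESS(accum) (accum / tin_sizes[reduce_dim])
```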
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

using namespace utils;

void resize_reduce_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  std::vector<int64_t> in_sizes = in->sizes();
  // out->virtual_resize(in_sizes);
}

void add_reduce_node(
    ComputeGraph& graph,
    ValueRef in,
    const int dim,
    ValueRef out,
    const std::string& op_name) {
  VK_CHECK_COND(
      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
      "Vulkan reduction only supports texture storage");

  const int64_t ndim = graph.dim_of(in);

  int32_t reduce_dim = dim;
  reduce_dim = normalize(reduce_dim, ndim);
  reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim);

  // Check that the concat dim is not the reduction dim, if the tensor has a
  // batch dim greater than 1.
  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
    VK_CHECK_COND(
        graph.concat_dim_of(in) != reduce_dim,
        "Reduce shader currently does not support concat dim == reduce dim");
    VK_CHECK_COND(
        graph.concat_dim_of(out) != reduce_dim,
        "Reduce shader currently does not support concat dim == reduce dim");
  }

  vkapi::ShaderInfo shader_descriptor;
  std::string kernel_name = op_name;
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // This should match the value of MAX_NTHREADS in the reduce shader.
  constexpr uint32_t max_nthreads = 16;

  const uint32_t nworkers_per_group = 4;
  const uint32_t ngroups = 4;
  VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads);

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  global_wg_size[reduce_dim] = 1;

  utils::uvec3 local_wg_size{1, 1, 1};
  local_wg_size[reduce_dim] = nworkers_per_group;
  const int other_dim_1 = (reduce_dim + 1) % 3;
  const int other_dim_2 = (reduce_dim + 2) % 3;
  int32_t group_dim;
  if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) {
    local_wg_size[other_dim_1] = ngroups;
    group_dim = other_dim_1;
  } else {
    local_wg_size[other_dim_2] = ngroups;
    group_dim = other_dim_2;
  }

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // shader_descriptor,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
      // Specialization Constants
      {graph.packed_dim_of(out), reduce_dim, group_dim},
      // Resizing Logic
      resize_reduce_node));
}

#define DEFINE_REDUCE_FN(op_name, out_arg_idx)                             \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {   \
    const IntListPtr dims_list = graph.get_int_list(args[1]);              \
    VK_CHECK_COND(dims_list->size() == 1);                                 \
    return add_reduce_node(                                                \
        graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name);    \
  }

DEFINE_REDUCE_FN(sum, 4)
DEFINE_REDUCE_FN(mean, 4)
DEFINE_REDUCE_FN(amax, 3)
DEFINE_REDUCE_FN(amin, 3)

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.sum.dim_IntList, sum);
  VK_REGISTER_OP(aten.mean.dim, mean);
  VK_REGISTER_OP(aten.amax.default, amax);
  VK_REGISTER_OP(aten.amin.default, amin);
}

} // namespace vkcompute
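The dispatch sizing in `add_reduce_node` collapses the global extent along the reduce dim to 1 (one output texel per reduction row), lays `nworkers_per_group` threads along the reduce dim, and places `ngroups` threads along whichever of the two remaining axes has the larger extent (that axis becomes `group_dim`). A small Python sketch of that selection, with `pick_wg_sizes` as a hypothetical helper, not part of the commit:

```python
# Python model of the local/global work-group sizing in add_reduce_node:
# the reduce dim collapses to 1 in the global size, NWORKERS_PER_GROUP threads
# run along the reduce dim, and NGROUPS threads run along the larger of the two
# remaining axes (which becomes group_dim).

NWORKERS_PER_GROUP = 4
NGROUPS = 4

def pick_wg_sizes(logical_limits, reduce_dim):
    global_wg = list(logical_limits)
    global_wg[reduce_dim] = 1

    local_wg = [1, 1, 1]
    local_wg[reduce_dim] = NWORKERS_PER_GROUP

    other_1 = (reduce_dim + 1) % 3
    other_2 = (reduce_dim + 2) % 3
    group_dim = other_1 if global_wg[other_1] > global_wg[other_2] else other_2
    local_wg[group_dim] = NGROUPS
    return global_wg, local_wg, group_dim

# Example: logical limits (W=8, H=16, combined C/N=2), reducing along W (dim 0).
print(pick_wg_sizes([8, 16, 2], reduce_dim=0))
# ([1, 16, 2], [4, 4, 1], 1)  -> 4 workers along W, 4 groups along H
```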
