Skip to content

Commit b9e001b

Browse files
committed
[ET-VK] Add 2D Reduction to Vulkan Backend
1 parent f332196 commit b9e001b

File tree

4 files changed

+256
-7
lines changed

4 files changed

+256
-7
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def register_reduce_op(features: OpFeatures):
532532

533533
def check_reduce_node(node: torch.fx.Node) -> bool:
534534
dim_list = node.args[1]
535-
if isinstance(dim_list, list) and len(dim_list) != 1:
535+
if isinstance(dim_list, list) and len(dim_list) > 2:
536536
return False
537537

538538
keepdim = node.args[2]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

// Tensor bindings. tout receives the reduced texels; tin is the tensor being
// reduced. The declaration order is significant, so keep it as is.
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

// tin_limits is compared against the scan position to reject out-of-range
// invocations; tin_sizes supplies the reduction loop extents and the element
// count along the packed dimension.
${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Specialization constants. The defaults below only apply when the host does
// not override them.
//   packed_dim  - dimension whose elements are packed into a texel
//   reduce_dim1 - first reduced dimension; worker threads are spread along it
//   reduce_dim2 - second reduced dimension; each worker walks it in full
//   group_dim   - dimension along which independent groups are laid out
layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim1 = 0;
layout(constant_id = 5) const int reduce_dim2 = 1;
layout(constant_id = 6) const int group_dim = 2;

// Number of threads that co-operate on a single reduction output (a more
// verbose name would be NWORKERS_PER_GROUP). One work group may host several
// such groups, each producing a distinct output.
#define NWORKERS 4

// Capacity of the shared array below, and therefore an upper bound on the
// total work group size: every thread writes to its own element, indexed as
// worker + group * NWORKERS.
#define MAX_NTHREADS 16

shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"
47+
48+
// Converts a thread id (tid.x = worker index, tid.y = group index) into that
// thread's slot in shared_vecs. The NWORKERS slots of one group are
// contiguous, and groups follow one another.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}
51+
52+
// The accumulator is seeded with the first value of the reduction region,
// because some reductions (e.g. amax, amin) need to start from an actual data
// point rather than from a fixed constant.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
// Folds one more value into the running accumulator.
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Final transform applied to the fully reduced accumulator. Needed by
// operators such as mean, which finish with one extra calculation.
#define POSTPROCESS(accum) ${POSTPROCESS}
60+
61+
/*
 * Reduces tin over reduce_dim1 and reduce_dim2 for the output position given
 * by scan_pos, and writes the reduced texel to tout.
 *
 * tid.x is this thread's worker index inside its group; tid.y selects the
 * group. Every worker accumulates a strided share of reduce_dim1 (walking all
 * of reduce_dim2 for each step), stores its partial result in shared memory,
 * and worker 0 of each group then combines the partials and writes the output.
 */
void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
  // Slot of shared_vecs owned by this thread.
  const int slot = tid_to_smi(tid);

  // Seed the accumulator from the texel at the origin of the reduced plane.
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  vec4 partial = INIT_ACCUM(load_texel(tin, scan_pos));

  const int extent1 = tin_sizes[reduce_dim1];
  const int extent2 = tin_sizes[reduce_dim2];

  // Outer walk: this worker visits every NWORKERS-th position along
  // reduce_dim1, starting at its own worker index. The counter and the scan
  // position advance in lock-step.
  int outer = tid.x;
  scan_pos[reduce_dim1] = tid.x;
  while (outer < extent1) {
    // Inner walk: the whole extent of reduce_dim2.
    scan_pos[reduce_dim2] = 0;
    for (int inner = 0; inner < extent2; inner++) {
      partial = UPDATE_ACCUM(partial, load_texel(tin, scan_pos));
      scan_pos[reduce_dim2]++;
    }
    outer += NWORKERS;
    scan_pos[reduce_dim1] += NWORKERS;
  }

  // Publish the partial result, then wait for the other threads.
  shared_vecs[slot] = partial;
  barrier();

  // Only worker 0 of each group finishes the reduction.
  if (tid.x != 0) {
    return;
  }

  // Combine the NWORKERS partial results that belong to this group.
  const int group_base = tid.y * NWORKERS;
  vec4 total = shared_vecs[group_base];
  for (int w = 1; w < NWORKERS; w++) {
    total = UPDATE_ACCUM(total, shared_vecs[group_base + w]);
  }

  // Number of valid elements in the final texel of the packed dimension
  // (0 means the final texel has no padding).
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Whether this thread is producing that final, possibly padded, texel.
  const bool is_last_texel =
      scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

  // Explicitly zero the padding lanes.
  if (is_last_texel && nspill > 0) {
    [[unroll]] for (int lane = nspill; lane < 4; lane++) {
      total[lane] = 0;
    }
  }

  // The output has extent 1 along both reduced dimensions.
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  write_texel(tout, scan_pos, POSTPROCESS(total));
}
113+
114+
// Entry point: derives the output position and the (worker, group) thread id
// for this invocation, then runs the 2D reduction if the position is valid.
void main() {
  // Output position of this invocation; both reduced dimensions collapse to 0.
  ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  out_pos[reduce_dim1] = 0;
  out_pos[reduce_dim2] = 0;

  // x: worker index along reduce_dim1, y: group index along group_dim.
  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim1],
      gl_LocalInvocationID[group_dim]);

  // NOTE(review): invocations that fail this test never reach the barrier()
  // inside reduce_2d - confirm that is acceptable on the targeted drivers.
  if (all(lessThan(out_pos, tin_limits))) {
    reduce_2d(tid, out_pos);
  }
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
reduce2d:
  # Defaults shared by every variant; they describe a plain sum.
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    INIT_ACCUM: VEC4_T(0)
    UPDATE_ACCUM: accum + new_val
    POSTPROCESS: accum
  # Every variant below is generated once per listed dtype.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    # Sum: uses the defaults unchanged.
    - NAME: sum2d
    # Mean: a sum divided by the number of reduced elements.
    - NAME: mean2d
      POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
    # Max / min: seeded from the first value rather than from zero.
    - NAME: amax2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: max(accum, new_val)
      POSTPROCESS: accum
    - NAME: amin2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: min(accum, new_val)
      POSTPROCESS: accum

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,24 @@ void resize_reduce_node(
3232
out->virtual_resize(new_sizes);
3333
}
3434

35+
// Resize callback for a 2D reduction node: the output takes the input's
// current sizes with both reduced dimensions set to 1.
//
// args[0].refs[0] is the output tensor and args[1].refs[0] the input tensor.
// resize_args.at(0) holds the list of dims to reduce; each entry is passed
// through normalize() before being used as an index into the size vector.
void resize_reduce2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  // Extract the dimensions to reduce over. Use bounds-checked access, like
  // the rest of this function, so that a list with fewer than two entries
  // fails loudly instead of reading past the end of the vector.
  const std::vector<int64_t> dims_list =
      graph->extract_int_or_symint_list(resize_args.at(0));
  const int32_t reduce_dim1_nchw = static_cast<int32_t>(dims_list.at(0));
  const int32_t reduce_dim2_nchw = static_cast<int32_t>(dims_list.at(1));

  std::vector<int64_t> new_sizes = in->sizes();
  new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
  new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
  out->virtual_resize(new_sizes);
}
52+
3553
utils::uvec3 reduce_global_wg_size(
3654
ComputeGraph* graph,
3755
const vkapi::ShaderInfo& shader,
@@ -137,15 +155,89 @@ void add_reduce_node(
137155
resize_reduce_node));
138156
}
139157

158+
// Adds a dispatch node that reduces `in` over two dimensions at once and
// writes the result to `out`.
//
// dims_ref must refer to a list of exactly two distinct dims (negative values
// are normalized against the rank of `in`). op_name selects the shader
// family; the kernel name is op_name + "2d" plus a dtype suffix derived from
// `out`. Only texture storage is supported.
void add_reduce2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dims_ref,
    const ValueRef out,
    const std::string& op_name) {
  VK_CHECK_COND(
      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
      "Vulkan reduction only supports texture storage");

  const int64_t ndim = graph.dim_of(in);

  // Extract the two dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph.extract_int_or_symint_list(dims_ref);
  VK_CHECK_COND(
      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

  int32_t reduce_dim1 = normalize(dims_list[0], ndim);
  int32_t reduce_dim2 = normalize(dims_list[1], ndim);

  // Convert to WHCN format
  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

  // Two entries that normalize to the same dimension (e.g. 1 and -3 on a 4-D
  // tensor) would make the shader walk one axis twice; reject them up front.
  VK_CHECK_COND(
      reduce_dim1 != reduce_dim2,
      "reduce2d requires two distinct dimensions");

  // NOTE(review): the shader indexes a 3-component position with these dims
  // and treats the packed dim as non-reduced - confirm that callers never
  // pass the batch dim or the packed dim as a reduction dim.

  // Check that the concat dim is not one of the reduction dims
  if (ndim == 4 && graph.size_at<int>(0, in) > 1) {
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
  }

  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // The group dim is the first of the three texture axes that is not being
  // reduced.
  int32_t group_dim = 0;
  for (int i = 0; i < 3; i++) {
    if (i != reduce_dim1 && i != reduce_dim2) {
      group_dim = i;
      break;
    }
  }

  const ValueRef reduce_dim1_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim1);
  const ValueRef reduce_dim2_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim2);
  const ValueRef group_dim_whcn_ref =
      graph.get_or_add_value_for_int(group_dim);

  // NOTE(review): reduce_global_wg_size / reduce_local_wg_size are shared
  // with the single-dim path - confirm that any positional reads of the
  // resize args in those callbacks match the four-entry layout used here.
  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      reduce_global_wg_size,
      reduce_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
      // Resize Args
      {dims_ref,
       reduce_dim1_whcn_ref,
       reduce_dim2_whcn_ref,
       group_dim_whcn_ref},
      // Resizing Logic
      resize_reduce2d_node));
}
225+
140226
// Defines an operator entry point named `op_name`. It reads the reduction dims
// from args[1]; a one-entry list goes to add_reduce_node, a two-entry list
// goes to add_reduce2d_node, and any other length fails VK_CHECK_COND. The
// output ref is args[out_arg_idx], and the stringified op_name is passed on as
// the shader base name. In this diff view the lines after a `-` gutter number
// appear to be the removed single-dim body, and those after a `+` gutter
// number its replacement.
#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
141227
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
142228
const std::vector<int64_t> dims_list = \
143-
graph.extract_int_or_symint_list(args[1]); \
144-
VK_CHECK_COND(dims_list.size() == 1); \
145-
const int64_t dim_val = dims_list.at(0); \
146-
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
147-
return add_reduce_node( \
148-
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
229+
graph.extract_int_or_symint_list(args[1]); \
230+
if (dims_list.size() == 1) { \
231+
const int64_t dim_val = dims_list.at(0); \
232+
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
233+
return add_reduce_node( \
234+
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
235+
} else if (dims_list.size() == 2) { \
236+
return add_reduce2d_node( \
237+
graph, args[0], args[1], args[out_arg_idx], #op_name); \
238+
} else { \
239+
VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
240+
} \
149241
}
150242

151243
DEFINE_REDUCE_FN(sum, 4)

0 commit comments

Comments
 (0)