Skip to content

Commit 112a09f

Browse files
[ET-VK] Add 2D Reduction to Vulkan Backend (#12860)
Summary: This change adds 2D reduction to the Vulkan delegate. Prior to this change, only 1D reduction was implemented. Models like MobileNetV3 and ResNet do 2D reduction, and their performance was being negatively impacted by the lack of a 2D reduction Vulkan implementation. cc @SS-JIA @manuelcandales @cbilgin --------- Co-authored-by: Mateusz Sluszniak <[email protected]>
1 parent 0a2bf93 commit 112a09f

File tree

4 files changed

+304
-6
lines changed

4 files changed

+304
-6
lines changed

backends/vulkan/op_registry.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import torch
1818

19+
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
20+
1921
from executorch.exir.dialects._ops import ops as exir_ops
2022

2123
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -373,7 +375,41 @@ def register_softmax_op():
373375
def register_reduce_op():
374376
def check_reduce_node(node: torch.fx.Node) -> bool:
375377
dim_list = node.args[1]
376-
if isinstance(dim_list, list) and len(dim_list) != 1:
378+
if isinstance(dim_list, list) and len(dim_list) > 2:
379+
return False
380+
381+
if isinstance(dim_list, list) and len(dim_list) == 2:
382+
# Try to get the memory layout for this node
383+
try:
384+
memory_layout = utils.get_node_memory_layout(node)
385+
386+
# If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
387+
if memory_layout is not None:
388+
for dim in dim_list:
389+
# For WIDTH_PACKED layout, dimension 3 (W) is packed
390+
# For HEIGHT_PACKED layout, dimension 2 (H) is packed
391+
# For CHANNELS_PACKED layout, dimension 1 (C) is packed
392+
if (
393+
(
394+
memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
395+
and dim == 3
396+
)
397+
or (
398+
memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
399+
and dim == 2
400+
)
401+
or (
402+
memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
403+
and dim == 1
404+
)
405+
):
406+
return False
407+
except (AssertionError, KeyError, AttributeError):
408+
# If we can't get memory layout information, we'll assume the dims aren't packed
409+
pass
410+
411+
keepdim = node.args[2]
412+
if isinstance(keepdim, bool) and not keepdim:
377413
return False
378414

379415
if len(node.args) > 2:
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
14+
${define_active_storage_type(STORAGE)}
15+
16+
#extension GL_EXT_control_flow_attributes : require
17+
18+
layout(std430) buffer;
19+
20+
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
22+
23+
${layout_declare_ubo(B, "ivec3", "tin_limits")}
24+
${layout_declare_ubo(B, "ivec4", "tin_sizes")}
25+
26+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
27+
28+
layout(constant_id = 3) const int packed_dim = 0;
29+
layout(constant_id = 4) const int reduce_dim1 = 0;
30+
layout(constant_id = 5) const int reduce_dim2 = 1;
31+
layout(constant_id = 6) const int group_dim = 2;
32+
33+
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
34+
// threads that will co-operate to compute one reduction output. There may be
35+
// multiple groups computing distinct reduction outputs within one work group.
36+
#define NWORKERS 4
37+
38+
// Sets an upper limit on the total size of a work group based on how many
39+
// elements are allocated in the shared memory array below. Each thread in the
40+
// work group will write into its assigned element in the shared array.
41+
#define MAX_NTHREADS 16
42+
43+
44+
shared vec4 shared_vecs[MAX_NTHREADS];
45+
46+
#include "indexing_utils.h"
47+
48+
// Maps a (worker, group) thread id to its index in shared_vecs. Each group
// occupies NWORKERS consecutive entries; tid.x selects the entry within the
// group selected by tid.y.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}
51+
52+
// Initializing the accumulator accepts the first value in the reduction row,
53+
// since some reduction operations (i.e. amax, amin) prefer to initialize with
54+
// a data point instead of a static value.
55+
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
56+
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
57+
// Useful for operators such as mean which want to perform a final calculation
58+
// with the accumulator.
59+
#define POSTPROCESS(accum) ${POSTPROCESS}
60+
61+
// Reduces the input over reduce_dim1 and reduce_dim2 for the output location
// described by scan_pos. The NWORKERS threads of a group each accumulate a
// strided slice of reduce_dim1 (covering all of reduce_dim2 for every position
// they visit), publish their partial result to shared_vecs, and the thread
// with tid.x == 0 combines the partials and writes the output texel.
//
// tid.x is the worker index within the group, tid.y is the group index within
// the work group.
void reduce_2d_non_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // Entry of shared_vecs that this thread writes its partial result to
  const int slot = tid_to_smi(tid);

  // Seed the accumulator with the texel at the origin of the reduction plane
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  vec4 partial = INIT_ACCUM(load_texel(tin, scan_pos));

  // Stride over reduce_dim1 in steps of NWORKERS starting from this worker's
  // index; for each position, walk the full extent of reduce_dim2.
  int outer = tid.x;
  scan_pos[reduce_dim1] = outer;
  while (outer < tin_sizes[reduce_dim1]) {
    scan_pos[reduce_dim2] = 0;
    for (int inner = 0; inner < tin_sizes[reduce_dim2]; ++inner) {
      partial = UPDATE_ACCUM(partial, load_texel(tin, scan_pos));
      scan_pos[reduce_dim2]++;
    }
    outer += NWORKERS;
    scan_pos[reduce_dim1] += NWORKERS;
  }

  // Publish the partial result, then wait for the other invocations before
  // any of the partials are read back
  shared_vecs[slot] = partial;
  barrier();

  // Only the first worker of each group combines and writes the result
  if (tid.x != 0) {
    return;
  }

  // Fold the group's partial results together, in worker order
  const int group_base = tid.y * NWORKERS;
  partial = shared_vecs[group_base];
  for (int worker = 1; worker < NWORKERS; ++worker) {
    partial = UPDATE_ACCUM(partial, shared_vecs[group_base + worker]);
  }

  // Number of valid elements in the final texel of the packed dimension; a
  // value of 0 means that texel has no padding elements
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Whether this thread's output lies in the final texel of the packed
  // dimension, which is the only texel that may contain padding elements
  const bool is_last_texel =
      scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

  // Explicitly zero out the padding elements
  if (nspill > 0 && is_last_texel) {
    [[unroll]] for (int lane = nspill; lane < 4; lane++) {
      partial[lane] = 0;
    }
  }

  // The output has extent 1 along both reduced dimensions
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  write_texel(tout, scan_pos, POSTPROCESS(partial));
}
113+
114+
// Entry point. Derives the output position and the (worker, group) thread id
// for this invocation, then runs the 2D reduction if the position is within
// the logical limits of the input.
void main() {
  // Worker index comes from the first reduction axis, group index from the
  // axis that is not being reduced
  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim1],
      gl_LocalInvocationID[group_dim]);

  // The reduction always starts from the origin of the two reduced axes
  ivec3 origin = ivec3(gl_GlobalInvocationID);
  origin[reduce_dim1] = 0;
  origin[reduce_dim2] = 0;

  // Invocations outside the input's logical limits do no work
  if (all(lessThan(origin, tin_limits))) {
    reduce_2d_non_packed_dim(tid, origin);
  }
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
reduce2d:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
INIT_ACCUM: VEC4_T(0)
12+
UPDATE_ACCUM: accum + new_val
13+
POSTPROCESS: accum
14+
generate_variant_forall:
15+
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
shader_variants:
19+
- NAME: sum2d
20+
- NAME: mean2d
21+
POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
22+
- NAME: amax2d
23+
INIT_ACCUM: first_val
24+
UPDATE_ACCUM: max(accum, new_val)
25+
POSTPROCESS: accum
26+
- NAME: amin2d
27+
INIT_ACCUM: first_val
28+
UPDATE_ACCUM: min(accum, new_val)
29+
POSTPROCESS: accum

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ void resize_reduce_node(
3333
graph->virtual_resize(out, new_sizes);
3434
}
3535

36+
/*
 * Resize callback for the 2D reduction node. Recomputes the output sizes from
 * the current input sizes by setting the extent of both reduced dimensions to
 * 1; all other dimensions are carried over unchanged.
 *
 * args[0] holds the output tensor and args[1] the input tensor. resize_args[0]
 * is the list of dimensions being reduced, which must contain exactly two
 * entries; negative entries are resolved against the input's rank via
 * normalize().
 */
void resize_reduce2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  vTensorPtr out = graph->get_tensor(args.at(0).refs[0]);
  vTensorPtr in = graph->get_tensor(args.at(1).refs[0]);

  // Extract the dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph->extract_int_or_symint_list(resize_args.at(0));
  // Guard the list length up front so a malformed list fails loudly instead
  // of being read out of bounds
  VK_CHECK_COND(
      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

  std::vector<int64_t> new_sizes = in->sizes();
  const int64_t ndim = static_cast<int64_t>(new_sizes.size());
  // Each reduced dimension collapses to a single element in the output
  for (const int64_t reduce_dim_nchw : dims_list) {
    new_sizes.at(normalize(reduce_dim_nchw, ndim)) = 1;
  }
  out->virtual_resize(new_sizes);
}
54+
3655
utils::uvec3 reduce_global_wg_size(
3756
ComputeGraph* graph,
3857
const vkapi::ShaderInfo& shader,
@@ -138,15 +157,101 @@ void add_reduce_node(
138157
resize_reduce_node));
139158
}
140159

160+
/*
 * Adds a dispatch node to the graph that reduces `in` over two dimensions in
 * a single shader invocation and writes the result to `out`.
 *
 * - in: the input tensor; must use texture storage.
 * - dims_ref: a list value holding exactly two distinct dimensions to reduce
 *   over. Negative entries are resolved against the rank of `in`.
 * - out: the output tensor; must use texture storage.
 * - op_name: name of the reduction; the shader is looked up as
 *   op_name + "2d" followed by a dtype suffix derived from `out`.
 *
 * Neither reduction dimension may be the packed dimension of `in` or `out`.
 * For 4 dimensional inputs with a batch size greater than 1, neither may be
 * the concat dimension either.
 */
void add_reduce2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dims_ref,
    const ValueRef out,
    const std::string& op_name) {
  VK_CHECK_COND(
      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
      "Vulkan reduction only supports texture storage");

  const int64_t ndim = graph.dim_of(in);

  // Extract the two dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph.extract_int_or_symint_list(dims_ref);
  VK_CHECK_COND(
      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

  int32_t reduce_dim1 = normalize(dims_list.at(0), ndim);
  int32_t reduce_dim2 = normalize(dims_list.at(1), ndim);

  // Two entries that resolve to the same dimension (e.g. 2 and -2 for a 4
  // dimensional input) would make the shader traverse one axis twice, so
  // reject them here.
  VK_CHECK_COND(
      reduce_dim1 != reduce_dim2, "reduce2d requires two distinct dimensions");

  // Convert to WHCN format
  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

  // Check that none of the reduction dims are packed
  const auto in_packed_dim = graph.packed_dim_of(in);
  const auto out_packed_dim = graph.packed_dim_of(out);
  VK_CHECK_COND(in_packed_dim != reduce_dim1);
  VK_CHECK_COND(in_packed_dim != reduce_dim2);
  VK_CHECK_COND(out_packed_dim != reduce_dim1);
  VK_CHECK_COND(out_packed_dim != reduce_dim2);

  // Check that the concat dim is not one of the reduction dims
  if (ndim == 4 && graph.size_at<int>(0, in) > 1) {
    const auto in_concat_dim = graph.concat_dim_of(in);
    const auto out_concat_dim = graph.concat_dim_of(out);
    VK_CHECK_COND(in_concat_dim != reduce_dim1);
    VK_CHECK_COND(in_concat_dim != reduce_dim2);
    VK_CHECK_COND(out_concat_dim != reduce_dim1);
    VK_CHECK_COND(out_concat_dim != reduce_dim2);
  }

  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // The group dim is the first of the three texture axes that is not being
  // reduced; it is passed to the shader as a specialization constant.
  int32_t group_dim = 0;
  for (int32_t axis = 0; axis < 3; axis++) {
    if (axis != reduce_dim1 && axis != reduce_dim2) {
      group_dim = axis;
      break;
    }
  }

  const ValueRef reduce_dim1_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim1);
  const ValueRef reduce_dim2_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim2);
  const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);

  // NOTE(review): reduce_global_wg_size and reduce_local_wg_size are defined
  // outside this function. Confirm that the positions at which they read the
  // resize args match the four-entry layout passed below.
  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      reduce_global_wg_size,
      reduce_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
      // Resize Args
      {dims_ref,
       reduce_dim1_whcn_ref,
       reduce_dim2_whcn_ref,
       group_dim_whcn_ref},
      // Resizing Logic
      resize_reduce2d_node));
}
239+
141240
#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
142241
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
143242
const std::vector<int64_t> dims_list = \
144243
graph.extract_int_or_symint_list(args[1]); \
145-
VK_CHECK_COND(dims_list.size() == 1); \
146-
const int64_t dim_val = dims_list.at(0); \
147-
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
148-
return add_reduce_node( \
149-
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
244+
if (dims_list.size() == 1) { \
245+
const int64_t dim_val = dims_list.at(0); \
246+
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
247+
return add_reduce_node( \
248+
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
249+
} \
250+
if (dims_list.size() == 2) { \
251+
return add_reduce2d_node( \
252+
graph, args[0], args[1], args[out_arg_idx], #op_name); \
253+
} \
254+
VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
150255
}
151256

152257
DEFINE_REDUCE_FN(sum, 4)

0 commit comments

Comments
 (0)