Skip to content

Commit 4292c64

Browse files
author
ssjia
committed
[ET-VK] Introduce specialized implementation for per-row reduction
Title says it all! This diff also adds support for argmin and argmax, but only for per-row reduction. Differential Revision: [D84716454](https://our.internmc.facebook.com/intern/diff/D84716454/) ghstack-source-id: 316415599 Pull Request resolved: #15161
1 parent 9e80387 commit 4292c64

File tree

8 files changed

+499
-0
lines changed

8 files changed

+499
-0
lines changed

backends/vulkan/op_registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
449449
return False
450450

451451
keepdim = try_find_keepdim_arg(node)
452+
# keepdim = False is not supported yet
452453
if isinstance(keepdim, bool) and not keepdim:
453454
return False
454455

@@ -461,6 +462,15 @@ def pick_io_storage_for_reduce(node: torch.fx.Node):
461462
input_tensor = node.args[0]
462463
ndim = input_tensor.meta["val"].ndim
463464
dim_list = node.args[1]
465+
466+
# For 1D reductions, a special case is implemented for reducing the width dim
467+
if isinstance(dim_list, list) and len(dim_list) == 1:
468+
if dim_list[0] == -1:
469+
inputs_storage = utils.ANY_TEXTURE.make_union(utils.CONTIGUOUS_BUFFER)
470+
outputs_storage = inputs_storage
471+
return inputs_storage, outputs_storage
472+
473+
# For 2D reductions, the packed dimension cannot be one of the reduced dims
464474
if isinstance(dim_list, list) and len(dim_list) == 2:
465475
reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
466476
reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef REDUCE_OP_DEFS_GLSLH
10+
#define REDUCE_OP_DEFS_GLSLH
11+
12+
// Accumulator state threaded through the init / update / merge / postprocess
// helpers in this header. `T` is not defined here; the including shader is
// expected to define it before including this file.
struct Accum {
13+
T val; // current reduced value (running sum, or current min / max candidate)
14+
uint idx; // index recorded together with `val` by the amin / amax helpers
15+
uint count; // 1 after init_accum, 0 after init_accum_zero; incremented by the sum helpers
16+
};
17+
18+
// Seeds `accum` from a single element: the element's value and index are
// recorded and the element count starts at one.
void init_accum(out Accum accum, T val, uint idx) {
  accum = Accum(val, idx, 1u);
}
23+
24+
// Resets `accum` to the empty state: zero value, index zero, and a count of
// zero (no elements seen).
void init_accum_zero(out Accum accum) {
  accum = Accum(T(0), 0u, 0u);
}
29+
30+
// Sum / Mean
31+
32+
// Folds one element into a running sum and bumps the element count.
// `idx` is unused here; it is accepted so that this function has the same
// signature as the amax / amin update functions.
void update_accum_sum(inout Accum accum, T val, uint idx) {
  accum.count++;
  accum.val = accum.val + val;
}
36+
37+
// Combines two partial sums: both the summed value and the element count of
// `other` are added into `accum`.
void merge_accum_sum(inout Accum accum, const Accum other) {
  accum.count = accum.count + other.count;
  accum.val = accum.val + other.val;
}
41+
42+
// Converts an accumulated sum into a mean by dividing the summed value by the
// number of accumulated elements (converted to T).
void postprocess_accum_mean(inout Accum accum) {
  const T num_elements = T(accum.count);
  accum.val = accum.val / num_elements;
}
45+
46+
// Amax (maximum value)
47+
48+
// Folds one element into a running maximum. A strictly larger value replaces
// both the stored value and index; an equal value only lowers the stored index,
// so ties resolve to the smallest index. The stored value is left untouched on
// a tie.
void update_accum_amax(inout Accum accum, T val, uint idx) {
  const bool is_new_max = val > accum.val;
  const bool is_earlier_tie = (val == accum.val) && (idx < accum.idx);

  if (is_new_max) {
    accum.idx = idx;
    accum.val = val;
  } else if (is_earlier_tie) {
    accum.idx = idx;
  }
}
58+
59+
// Combines two partial maxima. `other` wins outright when its value is strictly
// larger; when the values are equal only the index is updated, and only if
// `other` holds the smaller index. `count` is not modified.
void merge_accum_amax(inout Accum accum, const Accum other) {
  const bool other_is_larger = other.val > accum.val;
  const bool other_is_earlier_tie =
      (other.val == accum.val) && (other.idx < accum.idx);

  if (other_is_larger) {
    accum.idx = other.idx;
    accum.val = other.val;
  } else if (other_is_earlier_tie) {
    accum.idx = other.idx;
  }
}
69+
70+
// Amin (minimum value)
71+
72+
// Folds one element into a running minimum. A strictly smaller value replaces
// both the stored value and index; an equal value only lowers the stored index,
// so ties resolve to the smallest index. The stored value is left untouched on
// a tie.
void update_accum_amin(inout Accum accum, T val, uint idx) {
  const bool is_new_min = val < accum.val;
  const bool is_earlier_tie = (val == accum.val) && (idx < accum.idx);

  if (is_new_min) {
    accum.idx = idx;
    accum.val = val;
  } else if (is_earlier_tie) {
    accum.idx = idx;
  }
}
82+
83+
// Combines two partial minima.
//
// An accumulator with count == 0 (the state produced by init_accum_zero) holds
// no candidate:
//   - if `other` is empty, `accum` is left unchanged;
//   - if `accum` is empty, it becomes a full copy of `other` (including count,
//     so that it is not treated as empty again by a later merge).
//
// When both sides hold a candidate, `other` wins outright when its value is
// strictly smaller; when the values are equal only the index is updated, and
// only if `other` holds the smaller index. `count` is not modified in that
// case.
void merge_accum_amin(inout Accum accum, const Accum other) {
  if (other.count == 0) {
    // Nothing to merge; in particular, do not let an empty accumulator's
    // placeholder value / index take part in the tie-break below.
    return;
  }

  if (accum.count == 0) {
    accum = other;
    return;
  }

  if (other.val < accum.val) {
    accum.val = other.val;
    accum.idx = other.idx;
  } else if (other.val == accum.val && other.idx < accum.idx) {
    // Equal values: prefer the lower index, keep the stored value as is.
    accum.idx = other.idx;
  }
}
93+
94+
#endif // REDUCE_OP_DEFS_GLSLH
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${texel_load_component_type(DTYPE, "buffer")}
14+
15+
#define NUM_OUTPUTS_PER_WG 1
16+
#define NUM_WORKERS_PER_OUTPUT 64
17+
18+
${define_active_storage_type("buffer")}
19+
${define_required_extensions(DTYPE)}
20+
21+
#extension GL_EXT_control_flow_attributes : require
22+
23+
layout(std430) buffer;
24+
25+
#include "indexing.glslh"
26+
#include "reduce_op_defs.glslh"
27+
28+
$if OUTPUT_IS_INDICES:
29+
${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
30+
$else:
31+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
32+
33+
${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
34+
35+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
36+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
37+
38+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
39+
40+
// Shared memory for cooperative reduction
41+
shared Accum shared_values[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
42+
43+
#define init_fn ${INIT_ACCUM_FN}
44+
#define update_fn ${UPDATE_ACCUM_FN}
45+
#define merge_fn ${MERGE_ACCUM_FN}
46+
47+
$if POSTPROCESS_ACCUM_FN != "none":
48+
#define postprocess_fn ${POSTPROCESS_ACCUM_FN}
49+
50+
$if OOB_INIT_MODE == "zero":
51+
#define OOB_INIT_MODE 0
52+
$else:
53+
#define OOB_INIT_MODE 1
54+
55+
$if OUTPUT_IS_INDICES:
56+
#define OUTPUT_IS_INDICES
57+
58+
/*
 * Reduces one row of the input buffer to a single output element.
 *
 * gl_GlobalInvocationID.y selects the output element (and therefore the input
 * row, which starts at out_bufi * width(inp)). The workers of a row are
 * distinguished by gl_LocalInvocationID.x: each one accumulates the elements
 * at positions worker_id, worker_id + NUM_WORKERS_PER_OUTPUT, ... of the row,
 * publishes its partial result to shared memory, and the partial results are
 * then combined pairwise until slot 0 holds the result for the whole row.
 *
 * NOTE(review): shared_values is sized [NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT]
 * while the local size comes from specialization constants; this assumes the
 * dispatch uses a local size of (NUM_WORKERS_PER_OUTPUT, NUM_OUTPUTS_PER_WG, 1)
 * - confirm against the host code.
 */
void main() {
  const uint out_bufi = gl_GlobalInvocationID.y;

  // NOTE(review): barrier() is called further down, so this early return is
  // only safe if every invocation of a workgroup takes the same branch here
  // (i.e. all of them share the same out_bufi) - confirm against the dispatch.
  if (out_of_bounds(out_bufi, outp)) {
    return;
  }

  // Local indices
  const uint worker_id = gl_LocalInvocationID.x;
  const uint output_id = gl_LocalInvocationID.y;

  // Row length, read once and reused by the bounds check and the loop below.
  const uint row_width = width(inp);
  const uint in_bufi_base = out_bufi * row_width;

  Accum local_accum;
  if (worker_id < row_width) {
    // Seed the accumulator with the first element this worker is responsible for
    init_fn(local_accum, t_in[in_bufi_base + worker_id], worker_id);
  } else {
    // This worker has no element of its own (the row is shorter than the
    // number of workers); how it is seeded depends on the reduction op.
#if OOB_INIT_MODE == 0
    // Seed with the empty / zero accumulator
    init_accum_zero(local_accum);
#else
    // Seed with the first element of the row (i.e. amin, amax)
    init_fn(local_accum, t_in[in_bufi_base], 0);
#endif
  }

  // Accumulate the remaining elements assigned to this worker
  for (uint x = worker_id + NUM_WORKERS_PER_OUTPUT; x < row_width;
       x += NUM_WORKERS_PER_OUTPUT) {
    update_fn(local_accum, t_in[in_bufi_base + x], x);
  }

  // Publish the partial result so other workers can merge it
  shared_values[output_id][worker_id] = local_accum;

  memoryBarrierShared();
  barrier();

  // Pairwise merge: each step halves the number of active workers, folding
  // slot (worker_id + i) into slot worker_id.
  for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
    if (worker_id < i) {
      merge_fn(
          shared_values[output_id][worker_id],
          shared_values[output_id][worker_id + i]);
    }
    memoryBarrierShared();
    barrier();
  }

  // Worker 0 holds the merged result and writes the output element
  if (worker_id == 0) {
    local_accum = shared_values[output_id][0];
#ifdef postprocess_fn
    postprocess_fn(local_accum);
#endif

#ifdef OUTPUT_IS_INDICES
    // Index variants (argmin / argmax) emit the position of the selected element
    t_out[out_bufi] = int(local_accum.idx);
#else
    t_out[out_bufi] = local_accum.val;
#endif
  }
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
reduce_per_row_buffer:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
INIT_ACCUM_FN: init_accum
11+
UPDATE_ACCUM_FN: update_accum_sum
12+
MERGE_ACCUM_FN: merge_accum_sum
13+
POSTPROCESS_ACCUM_FN: none
14+
OOB_INIT_MODE: zero
15+
OUTPUT_IS_INDICES: false
16+
generate_variant_forall:
17+
DTYPE:
18+
- VALUE: float
19+
- VALUE: half
20+
- VALUE: int32
21+
shader_variants:
22+
- NAME: sum_per_row_buffer
23+
- NAME: mean_per_row_buffer
24+
POSTPROCESS_ACCUM_FN: postprocess_accum_mean
25+
- NAME: amax_per_row_buffer
26+
UPDATE_ACCUM_FN: update_accum_amax
27+
MERGE_ACCUM_FN: merge_accum_amax
28+
OOB_INIT_MODE: first_element
29+
- NAME: amin_per_row_buffer
30+
UPDATE_ACCUM_FN: update_accum_amin
31+
MERGE_ACCUM_FN: merge_accum_amin
32+
OOB_INIT_MODE: first_element
33+
- NAME: argmax_per_row_buffer
34+
UPDATE_ACCUM_FN: update_accum_amax
35+
MERGE_ACCUM_FN: merge_accum_amax
36+
OOB_INIT_MODE: first_element
37+
OUTPUT_IS_INDICES: true
38+
- NAME: argmin_per_row_buffer
39+
UPDATE_ACCUM_FN: update_accum_amin
40+
MERGE_ACCUM_FN: merge_accum_amin
41+
OOB_INIT_MODE: first_element
42+
OUTPUT_IS_INDICES: true
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Reduce.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
namespace vkcompute {
17+
18+
/**
 * Shared implementation behind the argmin / argmax operator entry points.
 *
 * Expects `args` to be laid out as (in, dim, keepdim, out). The input must use
 * buffer storage, and the reduction dim (0 when `dim` is None) must normalize
 * to the last dim of the input; both conditions are enforced with
 * VK_CHECK_COND. The actual work is delegated to add_reduce_per_row_node.
 *
 * @param graph   compute graph the node is added to
 * @param args    operator arguments, see layout above
 * @param op_name operator name forwarded to add_reduce_per_row_node
 */
void arg_reduce_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args,
    const std::string& op_name) {
  const ValueRef in = args.at(0);
  const ValueRef dim = args.at(1);
  // args.at(2) is keepdim; it is not consumed by this implementation.
  const ValueRef out = args.at(3);

  VK_CHECK_COND(graph.is_buffer_storage(in));

  // A None dim is treated as dim 0.
  const int64_t dim_val =
      graph.val_is_not_none(dim) ? graph.extract_scalar<int64_t>(dim) : 0;

  const int64_t ndim = graph.dim_of(in);
  const int64_t normalized_dim = normalize(dim_val, ndim);

  // Only reduction over the last dim is handled here.
  VK_CHECK_COND(normalized_dim == ndim - 1);

  add_reduce_per_row_node(graph, in, out, op_name);
}
42+
43+
// Operator entry point for argmin: forwards the graph and arguments to
// arg_reduce_impl with the op name "argmin".
void argmin(ComputeGraph& graph, const std::vector<ValueRef>& args) {
44+
arg_reduce_impl(graph, args, "argmin");
45+
}
46+
47+
// Operator entry point for argmax: forwards the graph and arguments to
// arg_reduce_impl with the op name "argmax".
void argmax(ComputeGraph& graph, const std::vector<ValueRef>& args) {
47+
arg_reduce_impl(graph, args, "argmax");
48+
}
50+
51+
// Operator registration block: associates aten.argmin.default and
// aten.argmax.default with the argmin / argmax functions defined above.
// NOTE(review): the exact registration mechanics live in the
// REGISTER_OPERATORS / VK_REGISTER_OP macros, which are not visible here.
REGISTER_OPERATORS {
52+
VK_REGISTER_OP(aten.argmin.default, argmin);
53+
VK_REGISTER_OP(aten.argmax.default, argmax);
54+
}
55+
56+
} // namespace vkcompute

0 commit comments

Comments
 (0)