Skip to content

Commit f63fa21

Browse files
committed
Update on "[ET-VK] Implement generic reduction shader + mean, sum, amax, amin"
## Context Introduce a generic shader that computes a reduction along a single dim, with `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/) [ghstack-poisoned]
2 parents a3f0b00 + ca23a9f commit f63fa21

File tree

9 files changed

+14
-413
lines changed

9 files changed

+14
-413
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,6 @@ def __contains__(self, op):
105105
]
106106

107107
NO_DYNAMIC_SHAPE = [
108-
# Reduction
109-
exir_ops.edge.aten.mean.dim,
110-
exir_ops.edge.aten.sum.dim_IntList,
111108
# Normalization
112109
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
113110
exir_ops.edge.aten.native_layer_norm.default,

backends/vulkan/runtime/graph/ops/glsl/reduce.glsl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
109109
// Iterate over the partial outputs to obtain the overall output
110110
int group_i = tid.y * NWORKERS;
111111
accum = shared_vecs[group_i++];
112-
for (int i = 1; i < NWORKERS; ++i, group_i++) {
112+
for (int i = 1; i < NWORKERS; i++, group_i++) {
113113
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
114114
}
115115

@@ -123,7 +123,7 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
123123

124124
// Explicitly set padding elements to 0
125125
if (is_last_texel && nspill > 0) {
126-
[[unroll]] for (int i = nspill; i < 4; ++i) {
126+
[[unroll]] for (int i = nspill; i < 4; i++) {
127127
accum[i] = 0;
128128
}
129129
}
@@ -165,7 +165,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
165165
// padding elements are ignored
166166
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
167167
const vec4 intex = load_texel(tin, scan_pos);
168-
for (int i = 0; i < nspill; ++i) {
168+
for (int i = 0; i < nspill; i++) {
169169
accum.x = UPDATE_ACCUM(accum.x, intex[i]);
170170
}
171171
}
@@ -179,13 +179,13 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
179179
// Iterate over the partial maximums to obtain the overall maximum
180180
int group_i = tid.y * NWORKERS;
181181
accum = shared_vecs[group_i++];
182-
for (int i = 1; i < NWORKERS; ++i, group_i++) {
182+
for (int i = 1; i < NWORKERS; i++, group_i++) {
183183
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
184184
}
185185
// Each element of the texel is itself a partial maximum; iterate over the
186186
// texel to find the actual maximum
187187
float accum_final = accum.x;
188-
[[unroll]] for (int i = 1; i < 4; ++i) {
188+
[[unroll]] for (int i = 1; i < 4; i++) {
189189
accum_final = UPDATE_ACCUM(accum[i], accum_final);
190190
}
191191

backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl

Lines changed: 0 additions & 108 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl

Lines changed: 0 additions & 95 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ void resize_reduce_node(
2525
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
2626
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
2727

28-
std::vector<int64_t> in_sizes = in->sizes();
29-
// out->virtual_resize(in_sizes);
28+
int dim = extra_args[0];
29+
30+
std::vector<int64_t> new_sizes = in->sizes();
31+
new_sizes[normalize(dim, new_sizes.size())] = 1;
32+
out->virtual_resize(new_sizes);
3033
}
3134

3235
void add_reduce_node(
@@ -48,12 +51,8 @@ void add_reduce_node(
4851
// Check that the concat dim is not the reduction dim, if the tensor has a
4952
// batch dim greater than 1.
5053
if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
51-
VK_CHECK_COND(
52-
graph.concat_dim_of(in) != reduce_dim,
53-
"Reduce shader currently does not support concat dim == reduce dim");
54-
VK_CHECK_COND(
55-
graph.concat_dim_of(out) != reduce_dim,
56-
"Reduce shader currently does not support concat dim == reduce dim");
54+
VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim);
55+
VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim);
5756
}
5857

5958
vkapi::ShaderInfo shader_descriptor;
@@ -97,7 +96,8 @@ void add_reduce_node(
9796
// Specialization Constants
9897
{graph.packed_dim_of(out), reduce_dim, group_dim},
9998
// Resizing Logic
100-
resize_reduce_node));
99+
resize_reduce_node,
100+
{dim}));
101101
}
102102

103103
#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \

0 commit comments

Comments (0)