Skip to content

Commit 8c22723

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] aten.var.dim in reduce"
Incorporated variance logic into reduce by adding additional logic. Differential Revision: [D75247432](https://our.internmc.facebook.com/intern/diff/D75247432/) [ghstack-poisoned]
2 parents 304fa4a + 579b9cd commit 8c22723

File tree

3 files changed

+19
-30
lines changed

3 files changed

+19
-30
lines changed

backends/vulkan/runtime/graph/ops/glsl/reduce_texture3d.glsl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ $if VARIANCE_MODE:
4646
// work group will write into its assigned element in the shared array.
4747
#define MAX_NTHREADS 16
4848

49-
shared vec4 shared_vecs[MAX_NTHREADS];
49+
shared VEC4_T shared_vecs[MAX_NTHREADS];
5050
// Second accumulator for variance mode - used for sum of squares, prev
5151
// accumulator is used for sum of values
52-
shared vec4 shared_sum_sq[MAX_NTHREADS];
52+
shared VEC4_T shared_sum_sq[MAX_NTHREADS];
5353
shared int shared_count[MAX_NTHREADS];
5454

5555
#include "indexing_utils.h"
@@ -58,9 +58,9 @@ int tid_to_smi(const ivec2 tid) {
5858
return tid.x + tid.y * NWORKERS;
5959
}
6060

61-
vec4 calculate_variance(vec4 sum, vec4 sum_sq, int count) {
62-
vec4 mean = sum / float(count);
63-
vec4 variance = (sum_sq / float(count)) - (mean * mean);
61+
VEC4_T calculate_variance(VEC4_T sum, VEC4_T sum_sq, int count) {
62+
VEC4_T mean = sum / float(count);
63+
VEC4_T variance = (sum_sq / float(count)) - (mean * mean);
6464

6565
if ((pc.unbiased != 0) && (count > 1)) {
6666
variance = variance * (float(count) / float(count - 1.0));
@@ -111,10 +111,10 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
111111
const int smi = tid_to_smi(tid);
112112

113113
scan_pos[reduce_dim] = 0;
114-
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
114+
VEC4_T accum = INIT_ACCUM(load_texel(tin, scan_pos));
115115

116116
#ifdef VARIANCE_MODE
117-
vec4 sum_sq = VEC4_T(0);
117+
VEC4_T sum_sq = VEC4_T(0);
118118
int count = 0;
119119
#endif
120120

@@ -123,7 +123,7 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
123123
// the reduction row
124124
for (int i = tid.x; i < tin_sizes[reduce_dim];
125125
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
126-
vec4 val = load_texel(tin, scan_pos);
126+
VEC4_T val = load_texel(tin, scan_pos);
127127
accum = UPDATE_ACCUM(accum, val);
128128
#ifdef VARIANCE_MODE
129129
sum_sq += val * val;
@@ -166,7 +166,7 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
166166
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
167167

168168
#ifdef VARIANCE_MODE
169-
vec4 variance = calculate_variance(accum, sum_sq, count);
169+
VEC4_T variance = calculate_variance(accum, sum_sq, count);
170170
#endif
171171

172172
// Explicitly set padding elements to 0
@@ -208,10 +208,10 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
208208
const int reduce_len = tin_sizes[packed_dim] - nspill;
209209

210210
scan_pos[reduce_dim] = 0;
211-
vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));
211+
VEC4_T accum = INIT_ACCUM(VEC4_T(load_texel(tin, scan_pos).x));
212212

213213
#ifdef VARIANCE_MODE
214-
vec4 sum_sq = VEC4_T(0);
214+
VEC4_T sum_sq = VEC4_T(0);
215215
int count = 0;
216216
#endif
217217

@@ -220,7 +220,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
220220
scan_pos[reduce_dim] = tid.x;
221221
for (int i = tid.x * 4; i < reduce_len;
222222
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
223-
vec4 val = load_texel(tin, scan_pos);
223+
VEC4_T val = load_texel(tin, scan_pos);
224224
accum = UPDATE_ACCUM(accum, val);
225225
#ifdef VARIANCE_MODE
226226
sum_sq += val * val;
@@ -231,7 +231,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
231231
// element of the texel needs to be processed individually such that the
232232
// padding elements are ignored
233233
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
234-
const vec4 val = load_texel(tin, scan_pos);
234+
const VEC4_T val = load_texel(tin, scan_pos);
235235
for (int i = 0; i < nspill; i++) {
236236
accum.x = UPDATE_ACCUM(accum.x, val[i]);
237237
#ifdef VARIANCE_MODE
@@ -280,7 +280,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
280280
}
281281

282282
scan_pos[reduce_dim] = tid.x;
283-
write_texel(tout, scan_pos, vec4(variance, 0, 0, 0));
283+
write_texel(tout, scan_pos, VEC4_T(variance, 0, 0, 0));
284284
#else
285285
// Each element of the texel is itself a partial maximum; iterate over the
286286
// texel to find the actual maximum
@@ -290,7 +290,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
290290
}
291291

292292
scan_pos[reduce_dim] = tid.x;
293-
write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
293+
write_texel(tout, scan_pos, POSTPROCESS(VEC4_T(accum_final, 0, 0, 0)));
294294
#endif
295295
}
296296
}

backends/vulkan/runtime/graph/ops/glsl/reduce_texture3d.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) Meta Platforms, Inc. and affiliates.
32
# All rights reserved.
43
#

backends/vulkan/runtime/graph/ops/impl/Reduce.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ void add_reduce_buffer_node(
6060

6161
std::vector<PushConstantDataInfo> push_constants;
6262
int32_t unbiased_int = static_cast<int32_t>(unbiased);
63-
push_constants.emplace_back(
64-
PushConstantDataInfo(&unbiased_int, sizeof(unbiased_int)));
63+
push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int));
6564

6665
graph.execute_nodes().emplace_back(new DispatchNode(
6766
graph,
@@ -137,8 +136,7 @@ void add_reduce_texture_node(
137136

138137
std::vector<PushConstantDataInfo> push_constants;
139138
int32_t unbiased_int = static_cast<int32_t>(unbiased);
140-
push_constants.emplace_back(
141-
PushConstantDataInfo(&unbiased_int, sizeof(unbiased_int)));
139+
push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int));
142140

143141
graph.execute_nodes().emplace_back(new DispatchNode(
144142
graph,
@@ -177,19 +175,11 @@ void add_reduce_node(
177175
}
178176

179177
#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
180-
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
181-
const IntListPtr dims_list = graph.get_int_list(args[1]); \
182-
VK_CHECK_COND(dims_list->size() == 1); \
183-
return add_reduce_node( \
184-
graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name); \
185-
}
186-
187-
#define DEFINE_VAR_FN(op_name, out_arg_idx) \
188178
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
189179
const IntListPtr dims_list = graph.get_int_list(args[1]); \
190180
VK_CHECK_COND(dims_list->size() == 1); \
191181
bool unbiased = false; \
192-
if (args.size() > 2) { \
182+
if (strcmp(#op_name, "var") == 0 && args.size() > 2) { \
193183
unbiased = graph.get_bool(args[2]); \
194184
} \
195185
return add_reduce_node( \
@@ -205,7 +195,7 @@ DEFINE_REDUCE_FN(sum, 4)
205195
DEFINE_REDUCE_FN(mean, 4)
206196
DEFINE_REDUCE_FN(amax, 3)
207197
DEFINE_REDUCE_FN(amin, 3)
208-
DEFINE_VAR_FN(var, 4)
198+
DEFINE_REDUCE_FN(var, 4)
209199

210200
REGISTER_OPERATORS {
211201
VK_REGISTER_OP(aten.sum.dim_IntList, sum);

0 commit comments

Comments
 (0)