
Commit 4de8d6e

[ET-VK] Implement `native_group_norm`

## Changes

* Add an implementation of the group norm operator. The operator is implemented in two stages: first, a reduction operator calculates the mean and standard deviation of each channel group; then, the normalization is applied in an elementwise fashion.

Differential Revision: [D77038778](https://our.internmc.facebook.com/intern/diff/D77038778/)

[ghstack-poisoned]
1 parent 12f6ad5 commit 4de8d6e
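
For reference, a sketch of the math the two stages implement (this is the standard group norm definition, not text taken from this commit): for an input of shape [N, C, H, W] split into G channel groups, the reduction stage computes, per batch n and group g,

    \mu_{n,g} = \frac{1}{(C/G)\,H\,W} \sum_{c \in g} \sum_{h,w} x_{n,c,h,w}
    \sigma^2_{n,g} = \frac{1}{(C/G)\,H\,W} \sum_{c \in g} \sum_{h,w} \left(x_{n,c,h,w} - \mu_{n,g}\right)^2
    \mathrm{rstd}_{n,g} = \frac{1}{\sqrt{\sigma^2_{n,g} + \epsilon}}

and the elementwise stage then applies

    y_{n,c,h,w} = \left(x_{n,c,h,w} - \mu_{n,g(c)}\right) \cdot \mathrm{rstd}_{n,g(c)} \cdot \gamma_c + \beta_c

where \gamma and \beta are the optional per-channel affine parameters of native_group_norm.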

File tree

8 files changed

+674
-4
lines changed


backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 32 additions & 0 deletions
@@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
   VK_THROW("Could not get dtype of value with type ", val.type());
 }

+bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (!is_buffer_storage(idx)) {
+    return false;
+  }
+  return is_contiguous(idx);
+}
+
+bool ComputeGraph::is_standard_channels_packed_texture_tensor(
+    const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (is_buffer_storage(idx)) {
+    return false;
+  }
+  return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
+}
+
+bool ComputeGraph::is_standard_width_packed_texture_tensor(
+    const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (is_buffer_storage(idx)) {
+    return false;
+  }
+  return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const vkapi::ScalarType dtype,

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 28 additions & 2 deletions
@@ -231,7 +231,7 @@ class ComputeGraph final {
   inline ptr_type get_##short_name(const ValueRef idx) { \
     return ptr_type(this, idx); \
   } \
-  inline bool val_is_##short_name(const ValueRef idx) { \
+  inline bool val_is_##short_name(const ValueRef idx) const { \
     return values_.at(idx).is##type_name(); \
   }

@@ -314,6 +314,32 @@ class ComputeGraph final {
     return values_.at(idx).toConstTensor().has_buffer_storage();
   }

+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has buffer storage
+   * 3. The buffer backed tensor at `idx` has a contiguous memory layout
+   */
+  bool is_contiguous_buffer_tensor(const ValueRef idx) const;
+
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has texture storage
+   * 3. The texture backed tensor at `idx` has a standard axis mapping
+   * 4. The texture backed tensor at `idx` is channels packed
+   */
+  bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
+
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has texture storage
+   * 3. The texture backed tensor at `idx` has a standard axis mapping
+   * 4. The texture backed tensor at `idx` is width packed
+   */
+  bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;
+
   inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
       const {
     return values_.at(maybe_view)

@@ -354,7 +380,7 @@ class ComputeGraph final {
     return values_.at(idx).toTensor().numel_ubo();
   }

-  inline bool has_standard_axis_map(const ValueRef idx) {
+  inline bool has_standard_axis_map(const ValueRef idx) const {
     return values_.at(idx).toTensor().has_standard_axis_map();
   }

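As a purely illustrative sketch of how an operator's argument checks might use these new predicates (this helper is not part of the commit; the function name is invented, and the VK_CHECK_COND assertion macro used elsewhere in the Vulkan backend is assumed here):

// Hypothetical argument check for the group norm reduction step: the shader
// expects a channels-packed texture input with the standard axis map, and
// contiguous buffer-backed tensors for the mean/rstd outputs.
void check_group_norm_reduce_args(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef mean,
    const ValueRef rstd) {
  VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(in));
  VK_CHECK_COND(graph.is_contiguous_buffer_tensor(mean));
  VK_CHECK_COND(graph.is_contiguous_buffer_tensor(rstd));
}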

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")}
${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")}

${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}

${layout_declare_ubo(B, "ivec4", "mean_strides")}
${layout_declare_ubo(B, "int", "mean_numel")}
${layout_declare_ubo(B, "ivec3", "in_limits")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

layout(push_constant) uniform PRECISION restrict Block {
  int group;
  float epsilon;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")}
const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout);

#define LOCAL_WORK_GROUP_SIZE 64
shared float shared_sum[LOCAL_WORK_GROUP_SIZE];
shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE];

/*
 * Computes the mean and standard deviation of one group of channels of the
 * input tensor for the group normalization operator.
 *
 * Given a tensor of shape [W, H, C, N], the mean and standard deviation
 * tensors will have a shape of [G, N] where G = C / group.
 *
 * The input tensor is assumed to be a channels-packed texture tensor with the
 * standard axis mapping. The output tensors are assumed to be contiguous
 * buffer tensors.
 *
 * Algorithm:
 * 1. Each shader invocation corresponds to one group in one batch
 * 2. The local work group cooperatively reduces over all spatial locations
 *    (H×W) and all channels within the group (C/group channels)
 * 3. Uses shared memory for efficient parallel reduction
 * 4. Main thread (local ID 0) writes the final mean and rstd to buffer
 *
 * Global work group size: {N, 1, 1}
 * N is the number of elements in the output (mean/rstd) buffer; each work
 * group computes one output element.
 *
 * Local work group size: {1, T, 1}
 * T should be a power of 2, recommended 64 or 128 threads. This allows
 * efficient tree-based reduction in shared memory. Each local group will
 * cooperate to compute the output element.
 *
 * Each shader invocation will compute the mean and standard deviation for one
 * channel group in the input, and write out the corresponding result.
 */
void group_norm_reduce_C_packed() {
  const int global_idx = int(gl_GlobalInvocationID.x);
  const int local_idx = int(gl_LocalInvocationID.y);

  // Calculate group dimensions
  const int D = in_sizes.z / group; // channels per group
  const int HxW = in_sizes.y * in_sizes.x; // spatial size
  const int group_size = D * HxW; // total elements per group

  // Convert global index to (group_idx, batch_idx)
  const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order);

  // Initialize local sums
  float local_sum = 0.0;
  float local_sum_sq = 0.0;
  int local_count = 0;

  // Calculate the range of channels for this group
  const int group_start_channel = mean_tidx.x * D;
  const int group_end_channel = group_start_channel + D;

  // Calculate the range of texels that contain channels from this group
  const int start_texel_idx = group_start_channel / 4;
  const int end_texel_idx = divup4(group_end_channel);
  const int texels_in_group = end_texel_idx - start_texel_idx;

  // Total texels to process across all spatial locations
  const int total_texels = texels_in_group * HxW;

  // Each thread processes a subset of texels
  const int texels_per_thread = (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE;
  const int start_texel = local_idx * texels_per_thread;
  const int end_texel = min(start_texel + texels_per_thread, total_texels);

  // Process assigned texels
  for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) {
    // Convert texel index to spatial and channel coordinates
    const int spatial_idx = texel_idx / texels_in_group;
    const int texel_in_group = texel_idx % texels_in_group;

    // Convert to spatial coordinates
    const int w = spatial_idx % in_sizes.x;
    const int h = spatial_idx / in_sizes.x;

    // Calculate the global texel index
    const int global_texel_idx = start_texel_idx + texel_in_group;

    // Convert to texture position using default axis mapping
    ivec3 tex_pos = ivec3(w, h, global_texel_idx);

    // Adjust for batch dimension if needed
    if (in_sizes.w > 1) {
      // default axis mapping means channels is the batch concat dim
      tex_pos.z += mean_tidx.y * divup4(in_sizes.z);
    }

    // Check bounds and load texel
    if (all(lessThan(tex_pos, in_limits))) {
      const VEC4_T texel_val = load_texel(t_in, tex_pos);

      // Process all components of the texel that belong to this group
      const int texel_start_channel = global_texel_idx * 4;
      for (int comp = 0; comp < 4; comp++) {
        const int current_channel = texel_start_channel + comp;

        // Check if this component belongs to the current group
        if (current_channel >= group_start_channel && current_channel < group_end_channel) {
          const float val = texel_val[comp];
          local_sum += val;
          local_sum_sq += val * val;
          local_count++;
        }
      }
    }
  }

  // Store local results in shared memory
  shared_sum[local_idx] = local_sum;
  shared_sum_sq[local_idx] = local_sum_sq;

  // Synchronize threads
  memoryBarrierShared();
  barrier();

  // Perform tree-based reduction in shared memory
  for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) {
    if (local_idx < stride) {
      shared_sum[local_idx] += shared_sum[local_idx + stride];
      shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride];
    }
    memoryBarrierShared();
    barrier();
  }

  // Main thread writes the result
  if (local_idx == 0 && global_idx < mean_numel) {
    const float total_sum = shared_sum[0];
    const float total_sum_sq = shared_sum_sq[0];
    const float count = float(group_size);

    // Calculate mean and reciprocal standard deviation
    const float mean_val = total_sum / count;
    const float variance = (total_sum_sq / count) - (mean_val * mean_val);
    const float rstd_val = 1.0 / sqrt(variance + epsilon);

    // Write to buffer-backed tensors
    t_mean[global_idx] = T(mean_val);
    t_rstd[global_idx] = T(rstd_val);
  }
}

void main() {
  group_norm_reduce_C_packed();
}
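
A note on the closing arithmetic in the shader above, spelled out here for clarity (these are standard one-pass mean/variance identities, not new code): each work group accumulates only a running sum S_1 and sum of squares S_2 over the n = group_size elements of its channel group, which is enough to recover

    \mu = \frac{S_1}{n}, \qquad \sigma^2 = \frac{S_2}{n} - \mu^2, \qquad \mathrm{rstd} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}

which is exactly what the `mean_val`, `variance`, and `rstd_val` lines compute from `shared_sum[0]` and `shared_sum_sq[0]`.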
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

group_norm_reduce_texture:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: group_norm_reduce_texture
