
Commit 8d12f03

pytorchbot and hinriksnaer authored and committed
[ET-VK] Implement `native_group_norm` (pytorch#11973)
## Changes

* Add an implementation of the group norm operator. The operator is implemented in two stages: first, a reduction shader computes the mean and standard deviation of each channel group; then, the normalization is applied elementwise.

Differential Revision: [D77038778](https://our.internmc.facebook.com/intern/diff/D77038778/)
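For reference, the two stages correspond to the following CPU sketch (illustrative only, not code from this commit; the NCHW layout and all names are assumptions):

```cpp
// Illustrative CPU reference of the two-stage algorithm; not part of this
// commit. Stage 1 reduces each channel group to a mean and reciprocal
// standard deviation; stage 2 applies the normalization elementwise.
#include <cmath>
#include <vector>

void group_norm_reference(
    const std::vector<float>& in, // NCHW, contiguous
    std::vector<float>& out,      // same shape as `in`
    int N, int C, int H, int W, int group, float epsilon) {
  const int D = C / group; // channels per group
  const int HxW = H * W;   // spatial size
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < group; ++g) {
      // Stage 1: reduce over every element in the channel group.
      double sum = 0.0, sum_sq = 0.0;
      for (int d = 0; d < D; ++d) {
        const int c = g * D + d;
        for (int i = 0; i < HxW; ++i) {
          const double v = in[(n * C + c) * HxW + i];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double count = double(D) * HxW;
      const double mean = sum / count;
      const double rstd =
          1.0 / std::sqrt(sum_sq / count - mean * mean + epsilon);
      // Stage 2: normalize elementwise using the per-group mean and rstd.
      for (int d = 0; d < D; ++d) {
        const int c = g * D + d;
        for (int i = 0; i < HxW; ++i) {
          const int idx = (n * C + c) * HxW + i;
          out[idx] = float((in[idx] - mean) * rstd);
        }
      }
    }
  }
}
```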
1 parent 887f82d commit 8d12f03

File tree

8 files changed, +672 −2 lines changed


backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 32 additions & 0 deletions
```diff
@@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
   VK_THROW("Could not get dtype of value with type ", val.type());
 }
 
+bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (!is_buffer_storage(idx)) {
+    return false;
+  }
+  return is_contiguous(idx);
+}
+
+bool ComputeGraph::is_standard_channels_packed_texture_tensor(
+    const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (is_buffer_storage(idx)) {
+    return false;
+  }
+  return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
+}
+
+bool ComputeGraph::is_standard_width_packed_texture_tensor(
+    const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (is_buffer_storage(idx)) {
+    return false;
+  }
+  return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const vkapi::ScalarType dtype,
```
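These predicates give operator implementations a single call to validate how each argument is represented before dispatching a shader. A hypothetical check for this operator (function and argument names are illustrative; `VK_CHECK_COND` is the backend's existing assertion macro):

```cpp
// Hypothetical argument validation for group_norm; the names `graph`, `in`,
// `mean`, and `rstd` are illustrative and not taken from this diff.
void check_group_norm_args(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef mean,
    const ValueRef rstd) {
  // The reduction shader reads a channels-packed texture with the standard
  // axis mapping...
  VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(in));
  // ...and writes mean/rstd into contiguous buffer-backed tensors.
  VK_CHECK_COND(graph.is_contiguous_buffer_tensor(mean));
  VK_CHECK_COND(graph.is_contiguous_buffer_tensor(rstd));
}
```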

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 28 additions & 2 deletions
```diff
@@ -231,7 +231,7 @@ class ComputeGraph final {
   inline ptr_type get_##short_name(const ValueRef idx) { \
     return ptr_type(this, idx);                          \
   }                                                      \
-  inline bool val_is_##short_name(const ValueRef idx) {  \
+  inline bool val_is_##short_name(const ValueRef idx) const { \
     return values_.at(idx).is##type_name();              \
   }
 
@@ -314,6 +314,32 @@ class ComputeGraph final {
     return values_.at(idx).toConstTensor().has_buffer_storage();
   }
 
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has buffer storage
+   * 3. The buffer backed tensor at `idx` has a contiguous memory layout
+   */
+  bool is_contiguous_buffer_tensor(const ValueRef idx) const;
+
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has texture storage
+   * 3. The texture backed tensor at `idx` has a standard axis mapping
+   * 4. The texture backed tensor at `idx` is channels packed
+   */
+  bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
+
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` has texture storage
+   * 3. The texture backed tensor at `idx` has a standard axis mapping
+   * 4. The texture backed tensor at `idx` is width packed
+   */
+  bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;
+
   inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
       const {
     return values_.at(maybe_view)
@@ -354,7 +380,7 @@ class ComputeGraph final {
     return values_.at(idx).toTensor().numel_ubo();
   }
 
-  inline bool has_standard_axis_map(const ValueRef idx) {
+  inline bool has_standard_axis_map(const ValueRef idx) const {
     return values_.at(idx).toTensor().has_standard_axis_map();
   }
```
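Note that the two `const`-qualification changes above are prerequisites rather than drive-by cleanup: the new predicate helpers are `const` member functions, and they call `val_is_tensor` (generated by the `val_is_##short_name` macro) and `has_standard_axis_map`, so both had to become `const` as well.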

backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl

Lines changed: 189 additions & 0 deletions
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define BUF_T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")}
${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")}

${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}

${layout_declare_ubo(B, "ivec4", "mean_strides")}
${layout_declare_ubo(B, "int", "mean_numel")}
${layout_declare_ubo(B, "ivec3", "in_limits")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

layout(push_constant) uniform PRECISION restrict Block {
  int group;
  float epsilon;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")}
const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout);

#define LOCAL_WORK_GROUP_SIZE 64
shared float shared_sum[LOCAL_WORK_GROUP_SIZE];
shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE];

/*
 * Computes the mean and standard deviation of one group of channels of the
 * input tensor for the group normalization operator.
 *
 * Given a tensor of shape [W, H, C, N], the mean and standard deviation
 * tensors will have a shape of [G, N], where G is the number of groups and
 * each group spans C / G channels.
 *
 * The input tensor is assumed to be a channels-packed texture tensor with the
 * standard axis mapping. The output tensors are assumed to be contiguous
 * buffer tensors.
 *
 * Algorithm:
 * 1. Each work group corresponds to one channel group in one batch
 * 2. The local work group cooperatively reduces over all spatial locations
 *    (H x W) and all channels within the group (C / group channels)
 * 3. Uses shared memory for efficient parallel reduction
 * 4. Main thread (local ID 0) writes the final mean and rstd to buffer
 *
 * Global work group count: {M, 1, 1}
 * M is the number of elements in the mean/rstd buffers; each work group
 * computes one output element.
 *
 * Local work group size: {1, S, 1}
 * S should be a power of 2 (LOCAL_WORK_GROUP_SIZE, 64 here; 128 also works
 * well). This allows an efficient tree-based reduction in shared memory.
 *
 * Each work group computes the mean and standard deviation for one channel
 * group in the input, and writes out the corresponding result.
 */
void group_norm_reduce_C_packed() {
  const int global_idx = int(gl_GlobalInvocationID.x);
  const int local_idx = int(gl_LocalInvocationID.y);

  // Calculate group dimensions
  const int D = in_sizes.z / group; // channels per group
  const int HxW = in_sizes.y * in_sizes.x; // spatial size
  const int group_size = D * HxW; // total elements per group

  // Convert global index to (group_idx, batch_idx)
  const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order);

  // Initialize local sums
  float local_sum = 0.0;
  float local_sum_sq = 0.0;
  int local_count = 0;

  // Calculate the range of channels for this group
  const int group_start_channel = mean_tidx.x * D;
  const int group_end_channel = group_start_channel + D;

  // Calculate the range of texels that contain channels from this group
  const int start_texel_idx = group_start_channel / 4;
  const int end_texel_idx = divup4(group_end_channel);
  const int texels_in_group = end_texel_idx - start_texel_idx;

  // Total texels to process across all spatial locations
  const int total_texels = texels_in_group * HxW;

  // Each thread processes a subset of texels
  const int texels_per_thread =
      (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE;
  const int start_texel = local_idx * texels_per_thread;
  const int end_texel = min(start_texel + texels_per_thread, total_texels);

  // Process assigned texels
  for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) {
    // Convert texel index to spatial and channel coordinates
    const int spatial_idx = texel_idx / texels_in_group;
    const int texel_in_group = texel_idx % texels_in_group;

    // Convert to spatial coordinates
    const int w = spatial_idx % in_sizes.x;
    const int h = spatial_idx / in_sizes.x;

    // Calculate the global texel index
    const int global_texel_idx = start_texel_idx + texel_in_group;

    // Convert to texture position using default axis mapping
    ivec3 tex_pos = ivec3(w, h, global_texel_idx);

    // Adjust for batch dimension if needed
    if (in_sizes.w > 1) {
      // default axis mapping means channels is the batch concat dim
      tex_pos.z += mean_tidx.y * divup4(in_sizes.z);
    }

    // Check bounds and load texel
    if (all(lessThan(tex_pos, in_limits))) {
      const vec4 texel_val = load_texel(t_in, tex_pos);

      // Process all components of the texel that belong to this group
      const int texel_start_channel = global_texel_idx * 4;
      for (int comp = 0; comp < 4; comp++) {
        const int current_channel = texel_start_channel + comp;

        // Check if this component belongs to the current group
        if (current_channel >= group_start_channel &&
            current_channel < group_end_channel) {
          const float val = texel_val[comp];
          local_sum += val;
          local_sum_sq += val * val;
          local_count++;
        }
      }
    }
  }

  // Store local results in shared memory
  shared_sum[local_idx] = local_sum;
  shared_sum_sq[local_idx] = local_sum_sq;

  // Synchronize threads
  memoryBarrierShared();
  barrier();

  // Perform tree-based reduction in shared memory
  for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) {
    if (local_idx < stride) {
      shared_sum[local_idx] += shared_sum[local_idx + stride];
      shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride];
    }
    memoryBarrierShared();
    barrier();
  }

  // Main thread writes the result
  if (local_idx == 0 && global_idx < mean_numel) {
    const float total_sum = shared_sum[0];
    const float total_sum_sq = shared_sum_sq[0];
    const float count = float(group_size);

    // Calculate mean and reciprocal standard deviation
    const float mean_val = total_sum / count;
    const float variance = (total_sum_sq / count) - (mean_val * mean_val);
    const float rstd_val = 1.0 / sqrt(variance + epsilon);

    // Write to buffer-backed tensors
    t_mean[global_idx] = BUF_T(mean_val);
    t_rstd[global_idx] = BUF_T(rstd_val);
  }
}

void main() {
  group_norm_reduce_C_packed();
}
```
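The final write computes variance with the single-pass identity Var[x] = E[x²] − (E[x])², followed by rstd = 1 / sqrt(Var[x] + epsilon), which is why only a running sum and sum of squares need to be reduced. A minimal standalone check of that arithmetic (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Four elements of one hypothetical channel group.
  const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float epsilon = 1e-5f;

  float sum = 0.0f, sum_sq = 0.0f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }

  const float count = 4.0f;
  const float mean = sum / count;                      // 10 / 4 = 2.5
  const float variance = sum_sq / count - mean * mean; // 7.5 - 6.25 = 1.25
  const float rstd = 1.0f / std::sqrt(variance + epsilon);
  std::printf("mean=%.4f rstd=%.4f\n", mean, rstd); // mean=2.5000 rstd~0.8944
  return 0;
}
```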
backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml

Lines changed: 15 additions & 0 deletions
```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

group_norm_reduce_texture:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: group_norm_reduce_texture
```
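The `generate_variant_forall` block tells the shader codegen to emit one variant of `group_norm_reduce_texture` per listed `DTYPE`. Dispatch code would then pick a variant by appending a dtype suffix to the base name; a sketch assuming the `add_dtype_suffix` helper that other ET-VK ops use for this purpose:

```cpp
// Assumed pattern, mirroring other ET-VK ops; add_dtype_suffix appends a
// suffix such as "_float" or "_half" based on the tensor's dtype.
std::string kernel_name = "group_norm_reduce_texture";
add_dtype_suffix(kernel_name, graph.dtype_of(in));
// kernel_name is now "group_norm_reduce_texture_float" or
// "group_norm_reduce_texture_half".
```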
