Skip to content

Commit 4d40abe

Browse files
committed
Update on "[ET-VK] Implement `native_group_norm`"
## Changes * Add an implementation of the group norm operator. The operator is implemented in two stages. First, a reduction shader computes the mean and reciprocal standard deviation (rstd) of each channel group. Then, the normalization is applied elementwise. Differential Revision: [D77038778](https://our.internmc.facebook.com/intern/diff/D77038778/) [ghstack-poisoned]
2 parents 177b310 + 009aa67 commit 4d40abe

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313

1414
#define PRECISION ${PRECISION}
1515

16-
#define VEC4_T ${texel_type(DTYPE)}
17-
18-
#define T ${buffer_scalar_type(DTYPE)}
19-
2016
${define_required_extensions(DTYPE)}
2117

2218
layout(std430) buffer;
@@ -67,8 +63,8 @@ shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE];
6763
* N is the number of elements in the tensor buffer; each thread computes one
6864
* output element.
6965
*
70-
* Local work group size: {1, T, 1}
71-
* T should be a power of 2, recommended 64 or 128 threads. This allows
66+
* Local work group size: {1, T, 1}
67+
* T should be a power of 2, recommended 64 or 128 threads. This allows
7268
* efficient tree-based reduction in shared memory. Each local group will
7369
* cooperate to compute the output element.
7470
*
@@ -133,7 +129,7 @@ void group_norm_reduce_C_packed() {
133129

134130
// Check bounds and load texel
135131
if (all(lessThan(tex_pos, in_limits))) {
136-
const VEC4_T texel_val = load_texel(t_in, tex_pos);
132+
const vec4 texel_val = load_texel(t_in, tex_pos);
137133

138134
// Process all components of the texel that belong to this group
139135
const int texel_start_channel = global_texel_idx * 4;
@@ -181,8 +177,8 @@ void group_norm_reduce_C_packed() {
181177
const float rstd_val = 1.0 / sqrt(variance + epsilon);
182178

183179
// Write to buffer-backed tensors
184-
t_mean[global_idx] = T(mean_val);
185-
t_rstd[global_idx] = T(rstd_val);
180+
t_mean[global_idx] = mean_val;
181+
t_rstd[global_idx] = rstd_val;
186182
}
187183
}
188184

backends/vulkan/test/op_tests/cases.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,12 +651,18 @@ def get_native_group_norm_inputs():
651651
test_suite = VkTestSuite(
652652
[
653653
# (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps)
654+
# General test cases
654655
((1, 8, 4, 4), (8), (8), 1, 8, 16, 2, 0.001),
655656
((2, 8, 3, 3), (8), (8), 2, 8, 9, 4, 0.001),
656657
((1, 12, 2, 2), (12), (12), 1, 12, 4, 3, 0.001),
657658
((3, 16, 5, 5), (16), (16), 3, 16, 25, 8, 0.001),
659+
((3, 16, 13, 17), (16), (16), 3, 16, 13 * 17, 4, 0.001),
658660
((1, 4, 7, 7), (4), (4), 1, 4, 49, 2, 0.001),
659661
((2, 6, 1, 8), (6), (6), 2, 6, 8, 3, 0.001),
662+
# Single group and prime number sizes
663+
((3, 7, 13, 11), (7), (7), 3, 7, 13 * 11, 1, 0.001),
664+
# Each channel is its own group and prime number sizes
665+
((1, 7, 13, 11), (7), (7), 1, 7, 13 * 11, 7, 0.001),
660666
]
661667
)
662668
test_suite.layouts = [
@@ -667,6 +673,7 @@ def get_native_group_norm_inputs():
667673
]
668674
test_suite.dtypes = [
669675
"at::kFloat",
676+
"at::kHalf",
670677
]
671678
test_suite.arg_storage_types = {
672679
"out": [None, "utils::kBuffer", "utils::kBuffer"],

0 commit comments

Comments
 (0)