Commit 926a760
Update base for Update on "[ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator"
## Changes

* Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output.

## Motivation

Required for the group norm operator.

## Future Work

Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) will have the same storage type and memory layout. As more operators are added, there are more exceptions to this rule. The pass may need an update in the near future to make it possible to specify required storage types and memory layouts at a more granular level.

Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/)

[ghstack-poisoned]
1 parent a9958a3 commit 926a760
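To make the per-output idea in the Changes section concrete, here is a minimal, purely illustrative Python sketch. None of these class or enum names are ExecuTorch/ET-VK API; only the idea of mapping each group norm output to its own storage type and memory layout comes from this commit (see the `utils::kBuffer` entries in the test diff below and the buffer writes in the reduce shader).

```python
# Illustrative sketch only -- not ExecuTorch/ET-VK API. It models the idea of
# giving each output of an operator its own storage type and memory layout.
from dataclasses import dataclass
from enum import Enum, auto


class StorageType(Enum):
    TEXTURE_3D = auto()  # image/texture-backed tensor
    BUFFER = auto()      # buffer-backed tensor


class MemoryLayout(Enum):
    CHANNELS_PACKED = auto()
    WIDTH_PACKED = auto()


@dataclass
class OutputSpec:
    storage: StorageType
    layout: MemoryLayout


# native_group_norm produces three outputs: out, mean, rstd. In this commit
# the main output keeps its default storage, while mean and rstd are
# buffer-backed. The specific layouts chosen here are assumptions for the
# sake of the example.
group_norm_output_specs = {
    "out": OutputSpec(StorageType.TEXTURE_3D, MemoryLayout.CHANNELS_PACKED),
    "mean": OutputSpec(StorageType.BUFFER, MemoryLayout.WIDTH_PACKED),
    "rstd": OutputSpec(StorageType.BUFFER, MemoryLayout.WIDTH_PACKED),
}
```

A per-output mapping like this is the granularity the Future Work section says `tag_memory_meta_pass` does not yet model for all participating tensors.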

File tree

2 files changed: +12, -9 lines

backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl

Lines changed: 5 additions & 9 deletions
```diff
@@ -13,10 +13,6 @@
 
 #define PRECISION ${PRECISION}
 
-#define VEC4_T ${texel_type(DTYPE)}
-
-#define T ${buffer_scalar_type(DTYPE)}
-
 ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
@@ -67,8 +63,8 @@ shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE];
  * N is the number of elements in the tensor buffer; each thread computes one
  * output element.
  *
- * Local work group size: {1, T, 1}
- * T should be a power of 2, recommended 64 or 128 threads. This allows
+ * Local work group size: {1, float, 1}
+ * float should be a power of 2, recommended 64 or 128 threads. This allows
  * efficient tree-based reduction in shared memory. Each local group will
  * cooperate to compute the output element.
  *
@@ -133,7 +129,7 @@ void group_norm_reduce_C_packed() {
 
   // Check bounds and load texel
   if (all(lessThan(tex_pos, in_limits))) {
-    const VEC4_T texel_val = load_texel(t_in, tex_pos);
+    const vec4 texel_val = load_texel(t_in, tex_pos);
 
     // Process all components of the texel that belong to this group
     const int texel_start_channel = global_texel_idx * 4;
@@ -181,8 +177,8 @@ void group_norm_reduce_C_packed() {
     const float rstd_val = 1.0 / sqrt(variance + epsilon);
 
     // Write to buffer-backed tensors
-    t_mean[global_idx] = T(mean_val);
-    t_rstd[global_idx] = T(rstd_val);
+    t_mean[global_idx] = mean_val;
+    t_rstd[global_idx] = rstd_val;
   }
 }
 
```
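The shader's doc comment describes a shared-memory, tree-based reduction over a power-of-two local work group. Below is a small Python sketch of that halving-stride pattern, purely to illustrate the access pattern; it is not the GLSL code, and the "threads" are simulated by a loop.

```python
# Illustrative sketch of the tree-based reduction named in the shader comment:
# a power-of-two work group cooperatively sums partial results, halving the
# number of active "threads" each step. In GLSL each step would be separated
# by a barrier(); here the threads are just iterations of an inner loop.
def tree_reduce(partial_sums):
    vals = list(partial_sums)
    assert len(vals) & (len(vals) - 1) == 0, "work group size must be a power of 2"
    stride = len(vals) // 2
    while stride > 0:
        for i in range(stride):      # thread i accumulates thread i + stride
            vals[i] += vals[i + stride]
        stride //= 2                 # half as many active threads next step
    return vals[0]


# e.g. 64 partial sums (one per thread) reduced to a single group result
print(tree_reduce([1.0] * 64))  # 64.0
```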
backends/vulkan/test/op_tests/cases.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -651,12 +651,18 @@ def get_native_group_norm_inputs():
     test_suite = VkTestSuite(
         [
             # (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps)
+            # General test cases
             ((1, 8, 4, 4), (8), (8), 1, 8, 16, 2, 0.001),
             ((2, 8, 3, 3), (8), (8), 2, 8, 9, 4, 0.001),
             ((1, 12, 2, 2), (12), (12), 1, 12, 4, 3, 0.001),
             ((3, 16, 5, 5), (16), (16), 3, 16, 25, 8, 0.001),
+            ((3, 16, 13, 17), (16), (16), 3, 16, 13 * 17, 4, 0.001),
             ((1, 4, 7, 7), (4), (4), 1, 4, 49, 2, 0.001),
             ((2, 6, 1, 8), (6), (6), 2, 6, 8, 3, 0.001),
+            # Single group and prime number sizes
+            ((3, 7, 13, 11), (7), (7), 3, 7, 13 * 11, 1, 0.001),
+            # Each channel is its own group and prime number sizes
+            ((1, 7, 13, 11), (7), (7), 1, 7, 13 * 11, 7, 0.001),
         ]
     )
     test_suite.layouts = [
@@ -667,6 +673,7 @@ def get_native_group_norm_inputs():
     ]
     test_suite.dtypes = [
         "at::kFloat",
+        "at::kHalf",
     ]
     test_suite.arg_storage_types = {
         "out": [None, "utils::kBuffer", "utils::kBuffer"],
```

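For reference, each tuple in the test list above follows the argument order of PyTorch's `aten::native_group_norm`, so a single case can be sanity-checked on CPU with public PyTorch ops as in the sketch below. The comparison against the Vulkan backend itself is outside this sketch.

```python
# Sketch: run one test-case tuple from cases.py through PyTorch's reference
# native_group_norm to see the three outputs the ET-VK shaders must produce.
import torch

# (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps)
input_shape, N, C, HxW, group, eps = (1, 8, 4, 4), 1, 8, 16, 2, 0.001

x = torch.randn(input_shape)
weight = torch.randn(C)
bias = torch.randn(C)

out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, HxW, group, eps
)

# mean and rstd hold one value per (batch, group) pair; these correspond to
# the buffer-backed t_mean / t_rstd tensors written by the reduce shader.
print(out.shape, mean.shape, rstd.shape)
# torch.Size([1, 8, 4, 4]) torch.Size([1, 2]) torch.Size([1, 2])
```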