From 9f58e6afd6a70ee664bae8c30d00f72699d60720 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Thu, 1 May 2025 12:10:39 -0700 Subject: [PATCH] Minor changes to native layer norm shader op to improve perf. (#10585) Summary: This diff improves perf by changing native layer norm shader to accumulate result in local variable instead of shared memory, and do a shared memory pass later. Reviewed By: SS-JIA Differential Revision: D73864950 --- .../graph/ops/glsl/native_layer_norm.glsl | 49 +++++++++---------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index c3e53cbfc3b..7897f0e8133 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout); // First iteration of reduce will have 32 threads sum up 64 elements. // Second iteration will have 32 threads sum up 16 elements from previous iteration and so on. // Thus thread utilization starts at 100%. -#define SHARED_MEMORY_FACTOR 2 +#define SHARED_MEMORY_FACTOR 1 -#define offset_pos_index(index) ((index) + ((index) >> 2)) +#define offset_pos_index(index) ((index) + ((index) >> 3)) shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)]; @@ -154,14 +154,13 @@ void reduce_non_packed_dim() { if (all(lessThan(in_pos, out_limits))) { in_val = load_texel(t_in, in_pos); } - shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val; + mean += in_val; } - - reduce_input(width_stride, shared_idx_offset); - mean += shared_input[offset_pos_index(shared_idx_offset)]; } - mean /= width; + shared_input[offset_pos_index(shared_idx)] = mean; + reduce_input(width_stride, shared_idx_offset); + mean = shared_input[offset_pos_index(shared_idx_offset)] / width; memoryBarrierShared(); barrier(); @@ -178,14 +177,13 @@ void reduce_non_packed_dim() { } const VEC4_T delta = in_val - mean; - shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta; + var += delta * delta; } - - reduce_input(width_stride, shared_idx_offset); - var += shared_input[offset_pos_index(shared_idx_offset)]; } - var /= width; + shared_input[offset_pos_index(shared_idx)] = var; + reduce_input(width_stride, shared_idx_offset); + var = shared_input[offset_pos_index(shared_idx_offset)] / width; VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); VEC4_T offset = -rstd * mean; @@ -226,6 +224,7 @@ void reduce_packed_dim() { const int in_pos_x_limit = out_limits[in_axis_map.x]; + VEC4_T accum = VEC4_T(0); // Loop over the width in stride increments for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) { // Read input in shared memory @@ -244,20 +243,20 @@ void reduce_packed_dim() { in_val.z = mix(in_val.z, T(0), remain_inv > 1); in_val.w = mix(in_val.w, T(0), remain_inv > 0); } - - shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val; + accum += in_val; } - - reduce_input(width_stride, shared_idx_offset); - const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)]; - mean += val.x + val.y + val.z + val.w; } - mean /= width; + shared_input[offset_pos_index(shared_idx)] = accum; + reduce_input(width_stride, shared_idx_offset); + VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)]; + mean = (val.x + val.y + val.z + val.w) / width; memoryBarrierShared(); barrier(); + VEC4_T delta2 = VEC4_T(0); + // Loop over the width in stride increments for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) { // Read input in shared memory @@ -278,16 +277,14 @@ void reduce_packed_dim() { } const VEC4_T delta = in_val - mean; - const VEC4_T delta2 = delta * delta; - shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2; + delta2 += delta * delta; } - - reduce_input(width_stride, shared_idx_offset); - const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)]; - var += val.x + val.y + val.z + val.w; } - var /= width; + shared_input[offset_pos_index(shared_idx)] = delta2; + reduce_input(width_stride, shared_idx_offset); + val = shared_input[offset_pos_index(shared_idx_offset)]; + var = (val.x + val.y + val.z + val.w) / width; T rstd = pow(var + epsilon, T(-0.5)); T offset = -rstd * mean;