Skip to content

Commit 324f021

Browse files
nathanaelsee and facebook-github-bot
authored and committed
update native_layer_norm to new layout gen & axis mapping (#6358)
Summary: Pull Request resolved: #6358 Naively using an ivec4 axis mapping regresses layer norm latency by 20-30%, because it adds another layer of index lookups inside the two loops that each traverse the entire width dim. We can use specialization constants to move the index lookups ahead of time, to the shader compilation and command buffer construction phase. Unfortunately, we can't pass vec types as specialization constants. But we can squeeze the axis mapping into a single 32-bit int and pass that in as a specialization constant! We can then unpack the int and create a const ivec4 axis map which can be folded during shader compilation. Using this method, we incur a 1% overhead instead of the 20+% we previously saw. This diff also adds a codegen function for specialization constants, along with a new accumulator `C` for constant ids (in addition to `B`, the binding index accumulator for textures, buffers and buffer objects). Reviewed By: SS-JIA Differential Revision: D63361329 fbshipit-source-id: d1ff1f20d6eb098e4b923de469cb2dad20f16044
1 parent b3932c0 commit 324f021

File tree

6 files changed

+87
-32
lines changed

6 files changed

+87
-32
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
# pyre-unsafe
9+
810
import argparse
911
import array
1012
import codecs
@@ -42,6 +44,10 @@
4244
# layout binding index when declaring layout bindings. Note that a container
4345
# type is used because integers are immutable in Python.
4446
"B": [0],
47+
# C is shorthand for "constant_id". This is used to automatically increment the
48+
# constant_id index for specialization constants.
49+
# Note that it starts at 3, as 0-2 are reserved for local workgroup size ids.
50+
"C": [3],
4551
}
4652

4753
# Establishes relationships between different tensor types and different GLSL types
@@ -300,14 +306,32 @@ def layout_declare_ubo(
300306
layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{
301307
"""
302308
for type_name, var_name in var_list:
303-
out_str += f"{type_name} {var_name};\n"
309+
out_str += f" {type_name} {var_name};\n"
304310
out_str += "};"
305311

306312
if isinstance(slot, list):
307313
slot[0] = slot[0] + 1
308314
return out_str
309315

310316

317+
def layout_declare_spec_const(
    slot: Union[int, List[int]],
    type_name: str,
    var_name: str,
    initial_val: Optional[str] = None,
) -> str:
    """Build a GLSL specialization constant declaration.

    Produces a single line of the form
    ``layout(constant_id = <id>) const <type_name> <var_name>[ = <initial_val>];``.

    Args:
        slot: The constant id, given either directly or as a one-element list
            acting as a mutable counter. The id is resolved with
            ``get_slot_val``; when a list is passed, its first element is
            incremented after the declaration has been built.
        type_name: GLSL scalar type; must be "int", "uint", "float" or "bool".
        var_name: Name of the declared constant.
        initial_val: Optional default value, emitted verbatim after ``=``.

    Returns:
        The declaration as a string.
    """
    assert type_name in ["int", "uint", "float", "bool"]

    constant_id = get_slot_val(slot)
    initializer = "" if initial_val is None else f" = {initial_val}"
    declaration = (
        f"layout(constant_id = {constant_id}) const {type_name} {var_name}"
        f"{initializer};"
    )

    # A list slot is a shared counter: advance it for the next declaration.
    if isinstance(slot, list):
        slot[0] += 1
    return declaration
333+
334+
311335
def define_active_storage_type(storage_type: str):
312336
if storage_type.lower() == "buffer":
313337
return "#define USING_BUFFER"
@@ -361,6 +385,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
361385
"layout_declare_sampler": layout_declare_sampler,
362386
"layout_declare_tensor": layout_declare_tensor,
363387
"layout_declare_ubo": layout_declare_ubo,
388+
"layout_declare_spec_const": layout_declare_spec_const,
364389
"define_active_storage_type": define_active_storage_type,
365390
"define_required_extensions": define_required_extensions,
366391
}

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
232232
imageStore(im, lpos_to_pos(lpos, axis_map), texel)
233233
#endif
234234

235+
// Converts a hashed axis mapping and a packed dim to an ivec4.
// The hash holds one axis index per byte: bits 16-23 become .x, bits 8-15
// become .y and bits 0-7 become .z; packed_dim is passed through as .w.
// e.g. 0x000102, 2 -> ivec4(0, 1, 2, 2)
// e.g. 0x010200, 1 -> ivec4(1, 2, 0, 1)
// Every macro argument is parenthesized so that expression arguments expand
// with the intended precedence, and every component is masked to its own byte.
#define UNHASH_AXIS_MAP(hash, packed_dim) \
  ivec4(                                  \
      ((hash) >> 16) & 0xFF,              \
      ((hash) >> 8) & 0xFF,               \
      (hash) & 0xFF,                      \
      (packed_dim))
// Hash whose unpacked axis indices are (0, 1, 2).
#define DEFAULT_AXIS_MAP_HASH 0x000102
241+
235242
/************************
236243
* Deprecated Functions *
237244
************************/

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,32 @@
1717

1818
layout(std430) buffer;
1919

20-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21-
layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
22-
layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;
20+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "w", "t_mean", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "w", "t_rstd", DTYPE, STORAGE)}
2323

24-
layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
25-
layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
26-
layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;
24+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
25+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
26+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)}
2727

28-
layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits {
29-
ivec3 out_limits;
30-
};
28+
${layout_declare_ubo(B, "ivec3", "out_limits")}
29+
${layout_declare_ubo(B, "ivec4", "sizes")}
30+
${layout_declare_ubo(B, "float", "epsilon")}
3131

32-
layout(set = 0, binding = 7) uniform PRECISION restrict Sizes {
33-
ivec4 sizes;
34-
};
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3533

36-
layout(set = 0, binding = 8) uniform PRECISION restrict Epsilon {
37-
float epsilon;
38-
};
34+
${layout_declare_spec_const(C, "int", "in_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")}
35+
${layout_declare_spec_const(C, "int", "in_packed_dim", "C_DIM")}
36+
const ivec4 in_axis_map = UNHASH_AXIS_MAP(in_axis_map_hash, in_packed_dim);
3937

40-
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
38+
${layout_declare_spec_const(C, "int", "out_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")}
39+
${layout_declare_spec_const(C, "int", "out_packed_dim", "C_DIM")}
40+
const ivec4 out_axis_map = UNHASH_AXIS_MAP(out_axis_map_hash, out_packed_dim);
4141

4242
void main() {
43-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
43+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
4444

45-
if (any(greaterThanEqual(pos, out_limits))) {
45+
if (any(greaterThanEqual(lpos, out_limits))) {
4646
return;
4747
}
4848

@@ -55,8 +55,10 @@ void main() {
5555

5656
// Use Welford's online algorithm to compute mean and variance in one pass
5757
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
58+
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
5859
for (int w = 0; w < width; ++w) {
59-
VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
60+
in_pos[in_axis_map.x] = w;
61+
VEC4_T v = load_texel(t_in, in_pos);
6062
delta = v - mean;
6163
mean += delta / (w + 1);
6264
delta2 = v - mean;
@@ -68,14 +70,15 @@ void main() {
6870
VEC4_T offset = -rstd * mean;
6971

7072
for (int w = 0; w < width; ++w) {
71-
VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
73+
in_pos[in_axis_map.x] = w;
74+
VEC4_T v = load_texel(t_in, in_pos);
7275
// broadcasting
73-
VEC4_T weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
74-
VEC4_T bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
76+
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
77+
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
7578
VEC4_T outtex = (v * rstd + offset) * weight + bias;
76-
imageStore(image_out, ivec3(w, pos.y, pos.z), outtex);
79+
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
7780
}
7881

79-
imageStore(image_mean, pos, mean);
80-
imageStore(image_rstd, pos, rstd);
82+
write_texel(t_mean, lpos, mean);
83+
write_texel(t_rstd, lpos, rstd);
8184
}

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
native_layer_norm:
88
parameter_names_with_default_values:
9-
NDIM: 3
109
DTYPE: float
11-
PACKING: C_packed
10+
STORAGE: texture3d
1211
generate_variant_forall:
1312
DTYPE:
1413
- VALUE: half

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,18 @@ void add_native_layer_norm_node(
106106
vkapi::MemoryAccessType::WRITE},
107107
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
108108
// Shader params buffers
109-
{t_out->logical_limits_ubo(),
110-
t_out->sizes_ubo(),
111-
graph.create_params_buffer(epsilon)},
109+
{
110+
t_out->logical_limits_ubo(),
111+
t_out->sizes_ubo(),
112+
graph.create_params_buffer(epsilon),
113+
},
112114
// Specialization Constants
113-
{},
115+
{
116+
hash_axis_map(t_input->axis_map()),
117+
t_input->packed_dim(),
118+
hash_axis_map(t_out->axis_map()),
119+
t_out->packed_dim(),
120+
},
114121
// Resizing Logic
115122
resize_native_layer_norm_node,
116123
{normalized_shape}));

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,18 @@ T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) {
7979
return ndim - 1 - nchw_dim;
8080
}
8181

82+
//
83+
// Tensor axis map utilities
84+
//
85+
86+
// Converts an ivec4 axis map to a single int32_t, so that it can be passed as
// a specialization constant instead of a ubo. This allows the SPIR-V
// compilation step to perform compile-time folding on the axis map.
// Only the first 3 indices are converted, as the last index is the packed dim,
// which is passed separately.
// Layout: axis_map[0] occupies bits 16-23, axis_map[1] bits 8-15 and
// axis_map[2] bits 0-7. Each entry is masked to 8 bits so that an
// out-of-range value cannot spill into a neighbouring field.
// Throws std::out_of_range (from at()) if axis_map has fewer than 3 entries.
// Example: ivec4(0, 1, 2, 2) -> 0x000102
inline int32_t hash_axis_map(const std::vector<int64_t>& axis_map) {
  constexpr int64_t kAxisMask = 0xFF;
  const int64_t packed = ((axis_map.at(0) & kAxisMask) << 16) |
      ((axis_map.at(1) & kAxisMask) << 8) | (axis_map.at(2) & kAxisMask);
  // The masked result fits in 24 bits, so the explicit narrowing is lossless.
  return static_cast<int32_t>(packed);
}
95+
8296
} // namespace vkcompute

0 commit comments

Comments
 (0)