Skip to content

Commit c8d0244

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
update native_layer_norm to new layout gen & axis mapping
Summary: Naively using ivec4 axis mapping regresses latency by 20-30% for layer norm, due to the added overhead of another layer of index lookups over the 2 loops over the entire width dim. We can use specialization constants to move the index lookups ahead of time to the shader compilation and command buffer construction phase. Unfortunately, we can't pass vec types as specialization constants. But, we can squeeze the axis mapping into a single 32-bit int and pass that in as a specialization constant! We can unpack the int and create a const ivec4 axis map which can be folded during shader compilation. Using this method, we incur a 1% overhead instead of the 20+% we previously saw. This diff also adds a codegen function for specialization constants, along with a new accumulator `C` for constant ids (besides `B` for binding index for textures, buffers and buffer objects) Reviewed By: SS-JIA Differential Revision: D63361329
1 parent 7493aae commit c8d0244

File tree

6 files changed

+85
-32
lines changed

6 files changed

+85
-32
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
# layout binding index when declaring layout bindings. Note that a container
4343
# type is used because integers are immutable in Python.
4444
"B": [0],
45+
# C is shorthand for "constant_id". This is used to automatically increment the
46+
# constant_id index for specialization constants.
47+
# Note that it starts at 3, as 0-2 are reserved for local workgroup size ids.
48+
"C": [3],
4549
}
4650

4751
# Establishes relationships between different tensor types and different GLSL types
@@ -300,14 +304,32 @@ def layout_declare_ubo(
300304
layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{
301305
"""
302306
for type_name, var_name in var_list:
303-
out_str += f"{type_name} {var_name};\n"
307+
out_str += f" {type_name} {var_name};\n"
304308
out_str += "};"
305309

306310
if isinstance(slot, list):
307311
slot[0] = slot[0] + 1
308312
return out_str
309313

310314

315+
def layout_declare_spec_const(
316+
slot: Union[int, List[int]],
317+
type_name: str,
318+
var_name: str,
319+
initial_val: Optional[str] = None,
320+
) -> str:
321+
assert type_name in ["int", "uint", "float", "bool"]
322+
323+
out_str = f"layout(constant_id = {get_slot_val(slot)}) const {type_name} {var_name}"
324+
if initial_val is not None:
325+
out_str += f" = {initial_val}"
326+
out_str += ";"
327+
328+
if isinstance(slot, list):
329+
slot[0] = slot[0] + 1
330+
return out_str
331+
332+
311333
def define_active_storage_type(storage_type: str):
312334
if storage_type.lower() == "buffer":
313335
return "#define USING_BUFFER"
@@ -361,6 +383,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
361383
"layout_declare_sampler": layout_declare_sampler,
362384
"layout_declare_tensor": layout_declare_tensor,
363385
"layout_declare_ubo": layout_declare_ubo,
386+
"layout_declare_spec_const": layout_declare_spec_const,
364387
"define_active_storage_type": define_active_storage_type,
365388
"define_required_extensions": define_required_extensions,
366389
}

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
232232
imageStore(im, lpos_to_pos(lpos, axis_map), texel)
233233
#endif
234234

235+
// Converts hashed axis mapping and packed dim to a ivec4
236+
// e.g. 0x000102, 2 -> ivec4(0, 1, 2, 2)
237+
// e.g. 0x010200, 1 -> ivec4(1, 2, 0, 1)
238+
#define UNHASH_AXIS_MAP(hash, packed_dim) \
239+
ivec4(hash >> 16, (hash >> 8) & 0xFF, hash & 0xFF, packed_dim)
240+
#define DEFAULT_AXIS_MAP_HASH 0x000102
241+
235242
/************************
236243
* Deprecated Functions *
237244
************************/

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,32 @@
1717

1818
layout(std430) buffer;
1919

20-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21-
layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
22-
layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;
20+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "w", "t_mean", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "w", "t_rstd", DTYPE, STORAGE)}
2323

24-
layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
25-
layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
26-
layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;
24+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
25+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
26+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)}
2727

28-
layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits {
29-
ivec3 out_limits;
30-
};
28+
${layout_declare_ubo(B, "ivec3", "out_limits")}
29+
${layout_declare_ubo(B, "ivec4", "sizes")}
30+
${layout_declare_ubo(B, "float", "epsilon")}
3131

32-
layout(set = 0, binding = 7) uniform PRECISION restrict Sizes {
33-
ivec4 sizes;
34-
};
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3533

36-
layout(set = 0, binding = 8) uniform PRECISION restrict Epsilon {
37-
float epsilon;
38-
};
34+
${layout_declare_spec_const(C, "int", "in_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")}
35+
${layout_declare_spec_const(C, "int", "in_packed_dim", "C_DIM")}
36+
const ivec4 in_axis_map = UNHASH_AXIS_MAP(in_axis_map_hash, in_packed_dim);
3937

40-
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
38+
${layout_declare_spec_const(C, "int", "out_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")}
39+
${layout_declare_spec_const(C, "int", "out_packed_dim", "C_DIM")}
40+
const ivec4 out_axis_map = UNHASH_AXIS_MAP(out_axis_map_hash, out_packed_dim);
4141

4242
void main() {
43-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
43+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
4444

45-
if (any(greaterThanEqual(pos, out_limits))) {
45+
if (any(greaterThanEqual(lpos, out_limits))) {
4646
return;
4747
}
4848

@@ -55,8 +55,10 @@ void main() {
5555

5656
// Use Welford's online algorithm to compute mean and variance in one pass
5757
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
58+
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
5859
for (int w = 0; w < width; ++w) {
59-
VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
60+
in_pos[in_axis_map.x] = w;
61+
VEC4_T v = load_texel(t_in, in_pos);
6062
delta = v - mean;
6163
mean += delta / (w + 1);
6264
delta2 = v - mean;
@@ -68,14 +70,15 @@ void main() {
6870
VEC4_T offset = -rstd * mean;
6971

7072
for (int w = 0; w < width; ++w) {
71-
VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
73+
in_pos[in_axis_map.x] = w;
74+
VEC4_T v = load_texel(t_in, in_pos);
7275
// broadcasting
73-
VEC4_T weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
74-
VEC4_T bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
76+
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
77+
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
7578
VEC4_T outtex = (v * rstd + offset) * weight + bias;
76-
imageStore(image_out, ivec3(w, pos.y, pos.z), outtex);
79+
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
7780
}
7881

79-
imageStore(image_mean, pos, mean);
80-
imageStore(image_rstd, pos, rstd);
82+
write_texel(t_mean, lpos, mean);
83+
write_texel(t_rstd, lpos, rstd);
8184
}

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
native_layer_norm:
88
parameter_names_with_default_values:
9-
NDIM: 3
109
DTYPE: float
11-
PACKING: C_packed
10+
STORAGE: texture3d
1211
generate_variant_forall:
1312
DTYPE:
1413
- VALUE: half

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,18 @@ void add_native_layer_norm_node(
109109
vkapi::MemoryAccessType::WRITE},
110110
{{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
111111
// Shader params buffers
112-
{t_out->logical_limits_ubo(),
113-
t_out->sizes_ubo(),
114-
graph.create_params_buffer(epsilon)},
112+
{
113+
t_out->logical_limits_ubo(),
114+
t_out->sizes_ubo(),
115+
graph.create_params_buffer(epsilon),
116+
},
115117
// Specialization Constants
116-
{},
118+
{
119+
hash_axis_map(t_input->axis_map()),
120+
t_input->packed_dim(),
121+
hash_axis_map(t_out->axis_map()),
122+
t_out->packed_dim(),
123+
},
117124
// Resizing Logic
118125
resize_native_layer_norm_node,
119126
{normalized_shape}));

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,18 @@ T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) {
7979
return ndim - 1 - nchw_dim;
8080
}
8181

82+
//
83+
// Tensor axis map utilities
84+
//
85+
86+
// Converts ivec4 axis map to a single uint32_t, to be able to pass it as a
87+
// specialization constant instead of a ubo. This allows for the spir-v to
88+
// bytecode compilation to perform compile-time folding on the axis map.
89+
// Only converts the first 3 indices, as the last index is the packed dim,
90+
// which is passed separately.
91+
// Example: ivec4(0, 1, 2, 2) -> 0x000102
92+
inline int32_t hash_axis_map(const std::vector<int64_t>& axis_map) {
93+
return (axis_map.at(0) << 16) + (axis_map.at(1) << 8) + axis_map.at(2);
94+
}
95+
8296
} // namespace vkcompute

0 commit comments

Comments
 (0)