Skip to content

Commit 9cb166b

Browse files
committed
Update on "[ET-VK] Minor performance improvements to native layer norm."
This diff introduces minor performance improvements to the native layer norm function in the Vulkan backend of ExecuTorch. In this new approach: The mean and variance values are calculated in 2 separate passes. The shader is dispatched based on input texture size, and each input texel is read and stored in shared memory. The input stored in shared memory is then summed up using a reduce function. This implementation better utilizes a GPU's parallel processing capabilities. Differential Revision: [D72430290](https://our.internmc.facebook.com/intern/diff/D72430290/) [ghstack-poisoned]
2 parents 1407ff7 + 932def6 commit 9cb166b

File tree

8 files changed

+57
-11
lines changed

8 files changed

+57
-11
lines changed

backends/xnnpack/operators/op_slice_copy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def define_node(
6969
output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
7070
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]
7171

72-
slice_begin_index = cast(int, node.args[2])
72+
slice_begin_index = 0
73+
if len(node.args) > 2 and node.args[2]:
74+
slice_begin_index = cast(int, node.args[2])
7375
if slice_begin_index < 0:
7476
slice_begin_index = input_shape[dim_of_slice] + slice_begin_index
7577

backends/xnnpack/test/ops/test_slice_copy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def forward(self, x):
6969
# Note that two of the slices are optimized away as they are identity.
7070
self._test_slice_copy(ConvSlice(), inputs, 4, 2)
7171

72+
def test_fp32_slice_copy_default_start(self):
73+
"""
74+
XNNPACK supports default start in slice op.
75+
"""
76+
77+
class Slice(torch.nn.Module):
78+
def forward(self, x):
79+
return torch.ops.aten.slice.Tensor(x, 0, None, 2)
80+
81+
inputs = (torch.randn(5, 5),)
82+
self._test_slice_copy(Slice(), inputs, 1, 1)
83+
7284
def test_fp32_slice_copy_stride_non_1(self):
7385
"""
7486
XNNPACK does not support strided slicing.

extension/parallel/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ def define_common_targets():
1717
"@EXECUTORCH_CLIENTS",
1818
],
1919
deps = [
20-
"//executorch/runtime/kernel:thread_parallel_interface",
20+
"//executorch/extension/threadpool:threadpool",
2121
],
2222
)

extension/threadpool/targets.bzl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def define_common_targets():
2020
] + (["fb/threadpool_use_n_threads.h"] if not runtime.is_oss else [])
2121

2222
runtime.cxx_library(
23-
name = "threadpool",
23+
name = "threadpool_lib",
2424
srcs = _THREADPOOL_SRCS,
2525
deps = [
2626
"//executorch/runtime/core:core",
@@ -45,6 +45,38 @@ def define_common_targets():
4545
],
4646
)
4747

48+
runtime.cxx_library(
49+
name = "threadpool",
50+
# TODO: OSS doesn't have os:iphoneos. Sync buck2 prelude
51+
# update to add it and remove duplication.
52+
exported_deps = (select({
53+
# Major operating systems should be able to use threadpool.
54+
"ovr_config//os:linux": [":threadpool_lib"],
55+
"ovr_config//os:macos": [":threadpool_lib"],
56+
"ovr_config//os:windows": [":threadpool_lib"],
57+
"ovr_config//os:android": [":threadpool_lib"],
58+
"ovr_config//os:iphoneos": [":threadpool_lib"],
59+
# Machines without an operating system shouldn't.
60+
"ovr_config//os:none": ["//executorch/runtime/kernel:thread_parallel_interface"],
61+
# If we don't know what it is, disable threadpool out of caution.
62+
"DEFAULT": ["//executorch/runtime/kernel:thread_parallel_interface"],
63+
}) if not runtime.is_oss else select({
64+
# Major operating systems should be able to use threadpool.
65+
"ovr_config//os:linux": [":threadpool_lib"],
66+
"ovr_config//os:macos": [":threadpool_lib"],
67+
"ovr_config//os:windows": [":threadpool_lib"],
68+
"ovr_config//os:android": [":threadpool_lib"],
69+
# Machines without an operating system shouldn't.
70+
"ovr_config//os:none": ["//executorch/runtime/kernel:thread_parallel_interface"],
71+
# If we don't know what it is, disable threadpool out of caution.
72+
"DEFAULT": ["//executorch/runtime/kernel:thread_parallel_interface"],
73+
})),
74+
visibility = [
75+
"//executorch/...",
76+
"@EXECUTORCH_CLIENTS",
77+
],
78+
)
79+
4880
runtime.cxx_library(
4981
name = "cpuinfo_utils",
5082
srcs = [

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ _OPTIMIZED_ATEN_OPS = (
107107
op_target(
108108
name = "op_where",
109109
deps = [
110+
"//executorch/extension/threadpool:threadpool",
110111
"//executorch/kernels/portable/cpu/util:elementwise_util",
111-
"//executorch/runtime/kernel:thread_parallel_interface",
112112
],
113113
),
114114
)

kernels/optimized/lib_defs.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def define_libs(is_fbcode=False):
232232
"DEFAULT": [],
233233
}) + LIBBLAS_DEPS,
234234
exported_deps = [
235+
"//executorch/extension/threadpool:threadpool",
235236
"//executorch/kernels/optimized:libutils",
236237
"//executorch/runtime/core/exec_aten:lib",
237-
"//executorch/runtime/kernel:thread_parallel_interface",
238238
],
239239
**get_apple_framework_deps_kwargs(is_fbcode),
240240
)

kernels/portable/cpu/util/targets.bzl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def define_common_targets():
1212
runtime.cxx_library(
1313
name = "all_deps",
1414
deps = [
15+
"//executorch/extension/threadpool:threadpool",
1516
"//executorch/kernels/portable/cpu/util:functional_util",
1617
"//executorch/kernels/portable/cpu/util:broadcast_util",
1718
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
@@ -32,7 +33,6 @@ def define_common_targets():
3233
"//executorch/kernels/portable/cpu/util:slice_util",
3334
"//executorch/kernels/portable/cpu/util:elementwise_util",
3435
"//executorch/kernels/portable/cpu/util:upsample_util",
35-
"//executorch/runtime/kernel:thread_parallel_interface",
3636
],
3737
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
3838
)
@@ -111,7 +111,7 @@ def define_common_targets():
111111
":broadcast_util",
112112
":dtype_util",
113113
"//executorch/runtime/kernel:kernel_runtime_context",
114-
"//executorch/runtime/kernel:thread_parallel_interface",
114+
"//executorch/extension/threadpool:threadpool",
115115
],
116116
deps = [
117117
"//executorch/kernels/portable/cpu:scalar_utils",
@@ -245,7 +245,7 @@ def define_common_targets():
245245
srcs = [],
246246
exported_headers = ["functional_util.h"],
247247
exported_deps = [
248-
"//executorch/runtime/kernel:thread_parallel_interface",
248+
"//executorch/extension/threadpool:threadpool",
249249
],
250250
deps = [
251251
"//executorch/runtime/kernel:kernel_includes",
@@ -319,7 +319,7 @@ def define_common_targets():
319319
"//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
320320
],
321321
exported_deps = [
322-
"//executorch/runtime/kernel:thread_parallel_interface",
322+
"//executorch/extension/threadpool:threadpool",
323323
],
324324
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
325325
visibility = [

runtime/kernel/targets.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def define_common_targets():
5959
"//executorch/runtime/core/portable_type/c10/c10:c10",
6060
"//executorch/runtime/platform:platform",
6161
],
62+
# Don't depend on this target, depend on //executorch/extension/threadpool:threadpool.
6263
visibility = [
63-
"//executorch/...",
64-
"@EXECUTORCH_CLIENTS",
64+
"//executorch/extension/threadpool/...",
6565
],
6666
)
6767

0 commit comments

Comments
 (0)