Skip to content

Commit 3284023

Browse files
committed
Update on "[ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator"
## Changes * Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output. ## Motivation Required for the group norm operator. ## Future Work Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) will have the same storage type and memory layout. As more operators are being added, there are more exceptions to this rule. The pass may need an update in the near future to make it possible to specify required storage types and memory layouts on a more granular level. Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/) [ghstack-poisoned]
2 parents 82684dd + d7607af commit 3284023

File tree

4 files changed

+9
-19
lines changed

4 files changed

+9
-19
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
self.default_layout: VkMemoryLayout = default_memory_layout
9292
self.texture_limits = texture_limits
9393

94-
def propose_node_storage(
94+
def propose_node_storage( # noqa: C901
9595
self,
9696
node: torch.fx.Node,
9797
) -> Optional[VkStorageType]:

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
660660
exir_ops.edge.aten.native_group_norm.default,
661661
]
662662
)
663-
def register_ported_ops_with_prepacking(features: OpFeatures):
663+
def register_native_group_norm(features: OpFeatures):
664664
features.texture_impl = TextureImplFeatures(
665665
valid_packed_dims={PackedDim.CHANNELS},
666666
)

backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1615

1716
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1817

backends/vulkan/test/op_tests/cases.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
test_suites = {}
2828

2929

30-
def register_test_suite(aten_op, skip=True):
30+
def register_test_suite(aten_op):
3131
def test_suite_decorator(fn: Callable) -> Callable:
32-
if skip:
33-
return fn
3432
if isinstance(aten_op, str):
3533
test_suites[aten_op] = fn()
3634
elif isinstance(aten_op, list):
@@ -143,7 +141,7 @@ def get_linear_inputs():
143141
inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list]
144142
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
145143
inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list]
146-
inputs_list += [((3, 6, K), (N, K), (N)) for _, K, N in MKN_list]
144+
inputs_list += [((3, 6, K), (N, K), (N)) for M, K, N in MKN_list]
147145

148146
test_suite = VkTestSuite(inputs_list)
149147
test_suite.dtypes = ["at::kFloat"]
@@ -648,7 +646,7 @@ def get_native_layer_norm_inputs():
648646
return test_suite
649647

650648

651-
@register_test_suite("aten.native_group_norm.default", skip=False)
649+
@register_test_suite("aten.native_group_norm.default")
652650
def get_native_group_norm_inputs():
653651
test_suite = VkTestSuite(
654652
[
@@ -662,24 +660,18 @@ def get_native_group_norm_inputs():
662660
]
663661
)
664662
test_suite.layouts = [
665-
# "utils::kWidthPacked",
666-
# "utils::kHeightPacked",
667663
"utils::kChannelsPacked",
668664
]
669665
test_suite.storage_types = [
670-
# "utils::kBuffer",
671-
"utils::kTexture3D"
666+
"utils::kTexture3D",
667+
]
668+
test_suite.dtypes = [
669+
"at::kFloat",
672670
]
673-
test_suite.dtypes = ["at::kFloat"]
674671
test_suite.arg_storage_types = {
675672
"out": [None, "utils::kBuffer", "utils::kBuffer"],
676673
}
677674

678-
# test_suite.arg_data_gen_fn = {
679-
# "weight": "make_zeros_tensor",
680-
# "bias": "make_rand_tensor",
681-
# }
682-
683675
test_suite.prepacked_args = ["weight", "bias"]
684676
test_suite.requires_prepack = True
685677

@@ -798,7 +790,6 @@ def get_permute_inputs():
798790
]
799791
test_suite.dtypes = [
800792
"at::kFloat",
801-
"at::kInt",
802793
]
803794
return test_suite
804795

0 commit comments

Comments
 (0)