Skip to content

Commit ad03e24

Browse files
committed
Update on "[ET-VK] Implement `native_group_norm"
## Changes * Add implementation for the group norm operator. The operator is implemented via a 2 stage implementation. First, a reduction operator is executed to calculate the mean and standard deviation of each channel group. Then, the normalization is applied in an elementwise fashion. Differential Revision: [D77038778](https://our.internmc.facebook.com/intern/diff/D77038778/) [ghstack-poisoned]
2 parents aae67bf + fe02090 commit ad03e24

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1615

1716
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1817

backends/vulkan/test/op_tests/cases.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
test_suites = {}
2828

2929

30-
def register_test_suite(aten_op, skip=True):
30+
def register_test_suite(aten_op):
3131
def test_suite_decorator(fn: Callable) -> Callable:
32-
if skip:
33-
return fn
3432
if isinstance(aten_op, str):
3533
test_suites[aten_op] = fn()
3634
elif isinstance(aten_op, list):
@@ -143,7 +141,7 @@ def get_linear_inputs():
143141
inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list]
144142
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
145143
inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list]
146-
inputs_list += [((3, 6, K), (N, K), (N)) for _, K, N in MKN_list]
144+
inputs_list += [((3, 6, K), (N, K), (N)) for M, K, N in MKN_list]
147145

148146
test_suite = VkTestSuite(inputs_list)
149147
test_suite.dtypes = ["at::kFloat"]
@@ -648,7 +646,7 @@ def get_native_layer_norm_inputs():
648646
return test_suite
649647

650648

651-
@register_test_suite("aten.native_group_norm.default", skip=False)
649+
@register_test_suite("aten.native_group_norm.default")
652650
def get_native_group_norm_inputs():
653651
test_suite = VkTestSuite(
654652
[
@@ -662,24 +660,18 @@ def get_native_group_norm_inputs():
662660
]
663661
)
664662
test_suite.layouts = [
665-
# "utils::kWidthPacked",
666-
# "utils::kHeightPacked",
667663
"utils::kChannelsPacked",
668664
]
669665
test_suite.storage_types = [
670-
# "utils::kBuffer",
671-
"utils::kTexture3D"
666+
"utils::kTexture3D",
667+
]
668+
test_suite.dtypes = [
669+
"at::kFloat",
672670
]
673-
test_suite.dtypes = ["at::kFloat"]
674671
test_suite.arg_storage_types = {
675672
"out": [None, "utils::kBuffer", "utils::kBuffer"],
676673
}
677674

678-
# test_suite.arg_data_gen_fn = {
679-
# "weight": "make_zeros_tensor",
680-
# "bias": "make_rand_tensor",
681-
# }
682-
683675
test_suite.prepacked_args = ["weight", "bias"]
684676
test_suite.requires_prepack = True
685677

@@ -798,7 +790,6 @@ def get_permute_inputs():
798790
]
799791
test_suite.dtypes = [
800792
"at::kFloat",
801-
"at::kInt",
802793
]
803794
return test_suite
804795

0 commit comments

Comments
 (0)