Skip to content

Commit 5a6dce3

Browse files
committed
Update on "[ET-VK] Implement generic reduction shader + mean, sum, amax, amin"
## Context Introduce a generic shader to compute reduction along a single dim, and `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/) [ghstack-poisoned]
2 parents f63fa21 + 25ef970 commit 5a6dce3

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

backends/vulkan/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ runtime.python_library(
2727
"//executorch/backends/transforms:fuse_conv_with_clamp",
2828
"//executorch/backends/transforms:fuse_dequant_linear",
2929
"//executorch/backends/transforms:fuse_view_copy",
30-
"//executorch/backends/transforms:mean_to_sum_div",
3130
"//executorch/backends/transforms:remove_clone_ops",
3231
"//executorch/backends/vulkan/_passes:vulkan_passes",
3332
"//executorch/exir:graph_module",

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,9 @@ def forward(self, x):
725725

726726
self.lower_module_and_test_output(module, sample_inputs)
727727

728+
@unittest.skip(
729+
"Reduce shader does not support multiple reduction axes at the moment"
730+
)
728731
def test_vulkan_backend_sum_dim_list(self):
729732
class SumModule(torch.nn.Module):
730733
def __init__(self):
@@ -744,6 +747,9 @@ def forward(self, x):
744747
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
745748
)
746749

750+
@unittest.skip(
751+
"Reduce shader does not support multiple reduction axes at the moment"
752+
)
747753
def test_vulkan_backend_sum(self):
748754
class SumModule(torch.nn.Module):
749755
def __init__(self):
@@ -1441,6 +1447,9 @@ def forward(self, x):
14411447

14421448
self.lower_unary_module_and_test_output(GeluModule())
14431449

1450+
@unittest.skip(
1451+
"Reduce shader does not support multiple reduction axes at the moment"
1452+
)
14441453
def test_vulkan_backend_mean(self):
14451454
class MeanModule(torch.nn.Module):
14461455
def __init__(self, dims, keepdim=True):

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
1616
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
1717
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
18-
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1918
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2019

2120
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
@@ -65,7 +64,6 @@ def preprocess( # noqa: C901
6564
FuseViewCopyTransform(),
6665
FuseBatchNormWithConvPass(program),
6766
FuseClampPass(),
68-
MeanToSumDiv(),
6967
SpecPropPass(),
7068
ConstraintBasedSymShapeEvalPass(),
7169
RemoveLocalScalarDenseOpsTransform(),

0 commit comments

Comments
 (0)