Skip to content

Commit bd70337

Browse files
SS-JIAssjia
andauthored
[ET-VK][ez] Use XNNPACK's FuseBatchNorm pass (#13600)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #13597 * #13596 * #13595 * #13594 * #13593 * __->__ #13600 * #13599 * #13598 As title. Use XNNPACK's FuseBatchNorm pass since it can fuse into linear layers as well. Differential Revision: [D80741735](https://our.internmc.facebook.com/intern/diff/D80741735/) Co-authored-by: ssjia <[email protected]>
1 parent 975a4a3 commit bd70337

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

backends/vulkan/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ def define_common_targets(is_fbcode = False):
387387
"//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
388388
"//executorch/backends/vulkan/_passes:vulkan_passes",
389389
"//executorch/backends/vulkan/serialization:lib",
390+
"//executorch/backends/transforms:remove_getitem_op",
391+
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
390392
"//executorch/exir/backend:backend_details",
391393
],
392394
)

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
import executorch.backends.vulkan.utils as utils
1414

1515
from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
16-
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
17-
FuseBatchNormWithConvPass,
18-
)
1916
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
2017
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
2118
from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
@@ -40,6 +37,7 @@
4037
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
4138
serialize_vulkan_graph,
4239
)
40+
from executorch.backends.xnnpack._passes import FuseBatchNormPass
4341

4442
from executorch.exir.backend.backend_details import (
4543
BackendDetails,
@@ -162,7 +160,7 @@ def preprocess( # noqa: C901
162160
SqueezeUnsqueezeInputs(),
163161
FuseViewCopyTransform(),
164162
ViewCopyToSqueezeUnsqueezePass(),
165-
FuseBatchNormWithConvPass(program),
163+
FuseBatchNormPass(program),
166164
FuseClampPass(),
167165
],
168166
)

0 commit comments

Comments
 (0)