From ac6abefe6243caa4b2069e5bec299c4681c62581 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Tue, 27 May 2025 20:32:14 -0700 Subject: [PATCH] [ET-VK] Modifying should_squeeze function in SqueezeUnsqueezeInputs to not squeeze if significant axes are all 1 and trailing axes are all > 1. This diff modifies the `should_squeeze` function in `SqueezeUnsqueezeInputs` to not squeeze (return False) if significant axes are all 1 and trailing axes are all > 1. Differential Revision: [D75483587](https://our.internmc.facebook.com/intern/diff/D75483587/) [ghstack-poisoned] --- backends/vulkan/_passes/squeeze_unsqueeze_inputs.py | 8 +++++++- backends/vulkan/vulkan_preprocess.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py index b4337829d7f..c415249383e 100644 --- a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py +++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py @@ -32,7 +32,13 @@ def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore return shape[1] == 1 and shape[0] > 1 if len(shape) == 4: # No need to squeeze if all dims are 1 except the width dim - if all(dim == 1 for dim in shape[:-1]): + if shape[0] == shape[1] == shape[2] == 1: + return False + # No need to squeeze if batch and channel dims are 1 and height and width are > 1 + if shape[0] == shape[1] == 1 and shape[2] > 1 and shape[3] > 1: + return False + # No need to squeeze if batch dim is 1 and channel, height and width are > 1 + if shape[0] == 1 and shape[1] > 1 and shape[2] > 1 and shape[3] > 1: return False # Otherwise, check for squeezable dim return 1 in shape[:-1] diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 4200df3e131..bb3420be50e 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -150,7 +150,6 @@ def preprocess( 
# noqa: C901 program = apply_passes( program, [ - RemoveRedundantOpsTransform(), AddmmToLinearTransform(), FuseQuantizedOpsTransform(program), SqueezeUnsqueezeInputs(), @@ -158,6 +157,7 @@ def preprocess( # noqa: C901 ViewCopyToSqueezeUnsqueezePass(), FuseBatchNormWithConvPass(program), FuseClampPass(), + RemoveRedundantOpsTransform(), ], )