Skip to content
Merged
8 changes: 7 additions & 1 deletion backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore
return shape[1] == 1 and shape[0] > 1
if len(shape) == 4:
# No need to squeeze if all dims are 1 except the width dim
if all(dim == 1 for dim in shape[:-1]):
if shape[0] == shape[1] == shape[2] == 1:
return False
# No need to squeeze if batch and channel dims are 1 and height and width are > 1
if shape[0] == shape[1] == 1 and shape[2] > 1 and shape[3] > 1:
return False
# No need to squeeze if batch dim is 1 and channel, height and width are > 1
if shape[0] == 1 and shape[1] > 1 and shape[2] > 1 and shape[3] > 1:
return False
# Otherwise, check for squeezable dim
return 1 in shape[:-1]
Expand Down
Loading