Skip to content

Commit faab2e7

Browse files
[Relax] Fix the squeeze operator to behave consistently with torch (#18478)
This commit fixes the squeeze operator to behave consistently with PyTorch by implementing no-op behavior when squeezing dimensions that are not of size 1. Previously, squeeze(x, [1]) on a tensor with shape [32, 10, 5] would fail. Now, squeeze(x, [1]) on a tensor with shape [32, 10, 5] returns the original tensor without modification, matching PyTorch's behavior. This fixes compatibility issues when converting PyTorch models that use squeeze with dimensions that may not always be 1 during inference. This work was done in collaboration with guan404ming's commit d87841d.
1 parent 97d78aa commit faab2e7

File tree

5 files changed

+39
-18
lines changed

5 files changed

+39
-18
lines changed

include/tvm/topi/transform.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,11 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_ax
428428
if (val < 0) {
429429
val += static_cast<int>(x->shape.size());
430430
}
431-
if (IsConstInt(x->shape[val])) {
432-
ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
431+
// If a dimension is not 1, silently skip it (no-op).
432+
bool is_const = IsConstInt(x->shape[val]);
433+
if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
434+
axis_val.push_back(val);
433435
}
434-
axis_val.push_back(val);
435436
}
436437
}
437438

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2003,7 +2003,7 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
20032003
valid_dims = []
20042004
for d in dim:
20052005
axis = d if d >= 0 else len(shape) + d
2006-
if axis < len(shape) and shape[axis] == 1:
2006+
if axis < len(shape):
20072007
valid_dims.append(d)
20082008
# If no valid dims, use None to squeeze all size-1 dimensions
20092009
dim = valid_dims if valid_dims else None

src/relax/op/tensor/manipulate.cc

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,15 +1234,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
12341234
// Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1.
12351235
// When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic
12361236
const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
1237-
if (int_len != nullptr && int_len->value != 1) {
1238-
ctx->ReportFatal(Diagnostic::Error(call)
1239-
<< "Squeeze expects the input tensor shape values at the given axis "
1240-
"positions to be all 1. However, the tensor shape at axis "
1241-
<< axes[i] << " is " << shape_value.value()[axes[i]]
1242-
<< " which is not 1. If it is symbolic, please use MatchCast to cast it "
1243-
"to 1 before doing Squeeze.");
1237+
// If a dimension is not 1, silently skip it (no-op), matching PyTorch behavior.
1238+
if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) {
1239+
axis_removal_mask[axes[i]] = true;
12441240
}
1245-
axis_removal_mask[axes[i]] = true;
12461241
}
12471242
} else {
12481243
// When `axis` is not defined, squeeze all unit-length dimensions.

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5482,15 +5482,32 @@ def main(
54825482
input: R.Tensor((3, 1, 4, 1), dtype="float32")
54835483
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
54845484
with R.dataflow():
5485-
lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3])
5485+
lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[0, 1, 2, 3])
54865486
gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
54875487
R.output(gv)
54885488
return gv
54895489

5490+
class Squeeze3(Module):
5491+
def forward(self, input):
5492+
return input.squeeze(2)
5493+
5494+
@I.ir_module
5495+
class Expected3:
5496+
@R.function
5497+
def main(
5498+
inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
5499+
) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")):
5500+
with R.dataflow():
5501+
lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[2])
5502+
gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,)
5503+
R.output(gv)
5504+
return gv
5505+
54905506
example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
54915507

54925508
verify_model(Squeeze1(), example_args, {}, Expected1)
54935509
verify_model(Squeeze2(), example_args, {}, Expected2)
5510+
verify_model(Squeeze3(), example_args, {}, Expected3)
54945511

54955512

54965513
def test_stack():

tests/python/relax/test_op_manipulate.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one():
994994
x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
995995
x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
996996

997-
with pytest.raises(TVMError):
998-
bb.normalize(relax.op.squeeze(x0, [0]))
999-
_check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32"))
1000-
with pytest.raises(TVMError):
1001-
bb.normalize(relax.op.squeeze(x2, [0]))
997+
# Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op.
998+
_check_inference(
999+
bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32")
1000+
)
1001+
# Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze.
1002+
_check_inference(
1003+
bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32")
1004+
)
1005+
# Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0.
1006+
_check_inference(
1007+
bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32")
1008+
)
1009+
# Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze.
10021010
_check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2))
10031011

10041012

0 commit comments

Comments
 (0)