Skip to content

Commit 95466f1

Browse files
committed
Add test case: test_composite_shape_expression
1 parent 7593b4d commit 95466f1

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/python/relax/test_backend_transform_shape_lower.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,5 +893,32 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
893893
assert_structural_equal(Expected, After)
894894

895895

896+
def test_composite_shape_expression():
897+
"""When a ShapeExpr contains composite PrimExpr that haven't been computed yet,
898+
VMShapeLower should trigger computation before processing the shape.
899+
"""
900+
901+
@tvm.script.ir_module
902+
class Before:
903+
@R.function
904+
def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")) -> R.Tensor:
905+
R.func_attr({"relax.force_pure": True})
906+
x_0 = T.int64()
907+
x_1 = T.int64()
908+
x_2 = T.int64()
909+
x_3 = T.int64()
910+
# This creates a composite expression that was causing the crash:
911+
# T.int64(4) * (x_0 * x_1 * x_2 * x_3)
912+
new_shape = R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)])
913+
return R.reshape(x, new_shape)
914+
915+
# The test shoud not crash during VMShapeLower
916+
# We don't need to validate teh exact output, just that it doesn't crash
917+
after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
918+
919+
# The actual output structure is not as important as not crashing
920+
assert after is not None
921+
922+
896923
if __name__ == "__main__":
897924
tvm.testing.main()

0 commit comments

Comments
 (0)