diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index bbc227d1d559..3b192700e3ec 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -399,6 +399,23 @@ class VMShapeLowerMutator return ffi::GetRef(op); } + // Check if all expressions are computed if not mark variables as ready and trigger computation + for (const PrimExpr& expr : op->values) { + if (!expr->IsInstance()) { + auto it = slot_map_.find(expr); + if (it != slot_map_.end() && !it->second->value_computed) { + // If it's a variable, mark it as ready for computation + if (expr.as()) { + it->second->value_computed = true; + ready_vars_.push_back(it->second); + } + } + } + } + + // Trigger computation for any expression that are now ready + this->EmitOutstandingPrimExprCompute(); + ffi::Array args = {shape_heap_, PrimValue::Int64(static_cast(op->values.size()))}; for (PrimExpr expr : op->values) { diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 177d036107c1..277f8de638ff 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -893,5 +893,32 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): assert_structural_equal(Expected, After) +def test_composite_shape_expression(): + """When a ShapeExpr contains composite PrimExpr that haven't been computed yet, + VMShapeLower should trigger computation before processing the shape. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")) -> R.Tensor: + R.func_attr({"relax.force_pure": True}) + x_0 = T.int64() + x_1 = T.int64() + x_2 = T.int64() + x_3 = T.int64() + # This creates a composite expression that was causing the crash: + # T.int64(4) * (x_0 * x_1 * x_2 * x_3) + new_shape = R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)]) + return R.reshape(x, new_shape) + + # The test shoud not crash during VMShapeLower + # We don't need to validate teh exact output, just that it doesn't crash + after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before) + + # The actual output structure is not as important as not crashing + assert after is not None + + if __name__ == "__main__": tvm.testing.main()