Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,23 @@ class VMShapeLowerMutator
return ffi::GetRef<Expr>(op);
}

// Check if all expressions are computed if not mark variables as ready and trigger computation
for (const PrimExpr& expr : op->values) {
if (!expr->IsInstance<IntImmNode>()) {
auto it = slot_map_.find(expr);
if (it != slot_map_.end() && !it->second->value_computed) {
// If it's a variable, mark it as ready for computation
if (expr.as<tir::VarNode>()) {
it->second->value_computed = true;
ready_vars_.push_back(it->second);
}
}
}
}

// Trigger computation for any expression that are now ready
this->EmitOutstandingPrimExprCompute();

ffi::Array<Expr> args = {shape_heap_,
PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
for (PrimExpr expr : op->values) {
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relax/test_backend_transform_shape_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,5 +893,32 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
assert_structural_equal(Expected, After)


def test_composite_shape_expression():
"""When a ShapeExpr contains composite PrimExpr that haven't been computed yet,
VMShapeLower should trigger computation before processing the shape.
"""

@tvm.script.ir_module
class Before:
@R.function
def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")) -> R.Tensor:
R.func_attr({"relax.force_pure": True})
x_0 = T.int64()
x_1 = T.int64()
x_2 = T.int64()
x_3 = T.int64()
# This creates a composite expression that was causing the crash:
# T.int64(4) * (x_0 * x_1 * x_2 * x_3)
new_shape = R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)])
return R.reshape(x, new_shape)

# The test shoud not crash during VMShapeLower
# We don't need to validate teh exact output, just that it doesn't crash
after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)

# The actual output structure is not as important as not crashing
assert after is not None


if __name__ == "__main__":
tvm.testing.main()