Skip to content

Commit 7b9d3d9

Browse files
authored
[Relax] Move GetUsedVars to analysis module (#18632)
## Why The GetUsedVars function was defined locally in binding_rewrite.cc with a TODO comment suggesting it should be moved to the analysis module. This refactoring improves code organization by placing the utility function alongside other variable analysis functions. ## How - Move GetUsedVars implementation to analysis module - Add FFI registration and Python wrapper - Add parametrized test
1 parent a393b47 commit 7b9d3d9

File tree

6 files changed

+77
-12
lines changed

6 files changed

+77
-12
lines changed

include/tvm/relax/analysis.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <tvm/tir/index_map.h>
3535

3636
#include <functional>
37+
#include <set>
3738
#include <utility>
3839

3940
namespace tvm {
@@ -494,6 +495,17 @@ struct VarUsageInfo {
494495
*/
495496
VarUsageInfo CollectVarUsage(const Expr& expr);
496497

498+
/*!
499+
* \brief Get the used variables in an expression.
500+
*
501+
* This function collects all variables that are referenced within the given expression.
502+
*
503+
* \param expr The expression to analyze
504+
*
505+
* \return A set of variable nodes that are used in the expression
506+
*/
507+
TVM_DLL std::set<const VarNode*> GetUsedVars(const Expr& expr);
508+
497509
/*!
498510
* \brief Remove unused statements inside DataflowBlocks.
499511
*

python/tvm/relax/analysis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
free_symbolic_vars,
3333
free_vars,
3434
get_static_type,
35+
used_vars,
3536
get_var2val,
3637
has_reshape_pattern,
3738
name_to_binding,

python/tvm/relax/analysis/analysis.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,26 @@ def all_vars(expr: Expr) -> List[Var]:
312312
return _ffi_api.all_vars(expr)
313313

314314

315+
def used_vars(expr: Expr) -> List[Var]:
316+
"""
317+
Return all variables used in an expression.
318+
319+
This function collects all variable references within the given expression,
320+
which is useful for analyzing variable dependencies.
321+
322+
Parameters
323+
----------
324+
expr: Expr
325+
The expression to analyze.
326+
327+
Returns
328+
-------
329+
ret: List[Var]
330+
List of variables used in the expression.
331+
"""
332+
return _ffi_api.used_vars(expr) # type: ignore
333+
334+
315335
def all_global_vars(expr: Expr) -> List[GlobalVar]:
316336
"""
317337
Return all global variables from expression expr.

src/relax/analysis/udchain.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,29 @@ ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {
121121

122122
TVM_FFI_STATIC_INIT_BLOCK() {
123123
namespace refl = tvm::ffi::reflection;
124-
refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef);
124+
refl::GlobalDef()
125+
.def("relax.analysis.udchain", DataflowBlockUseDef)
126+
.def("relax.analysis.used_vars", [](const Expr& expr) {
127+
auto used_vars = GetUsedVars(expr);
128+
ffi::Array<Var> result;
129+
for (const VarNode* var_node : used_vars) {
130+
result.push_back(ffi::GetRef<Var>(var_node));
131+
}
132+
return result;
133+
});
125134
}
126135

127136
VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); }
128137

138+
std::set<const VarNode*> GetUsedVars(const Expr& expr) {
139+
class UsedVars : public ExprVisitor {
140+
public:
141+
std::set<const VarNode*> used_vars;
142+
void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
143+
} visitor;
144+
visitor.VisitExpr(expr);
145+
return std::move(visitor.used_vars);
146+
}
147+
129148
} // namespace relax
130149
} // namespace tvm

src/relax/ir/binding_rewrite.cc

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
*/
2424

2525
#include <tvm/ffi/reflection/registry.h>
26+
#include <tvm/relax/analysis.h>
2627
#include <tvm/relax/binding_rewrite.h>
2728
#include <tvm/relax/block_builder.h>
2829
#include <tvm/relax/expr.h>
@@ -134,17 +135,6 @@ class UpdateDFB : public ExprMutator {
134135
}
135136
};
136137

137-
// TODO(masahi): Consider moving this to analysis
138-
std::set<const VarNode*> GetUsedVars(Expr val) {
139-
class UsedVars : public ExprVisitor {
140-
public:
141-
std::set<const VarNode*> used_vars;
142-
void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
143-
} uvar{};
144-
uvar.VisitExpr(val);
145-
return std::move(uvar.used_vars);
146-
}
147-
148138
void DataflowBlockRewriteNode::Add(Binding binding) {
149139
auto [var, val] = [binding] {
150140
if (auto vb = binding.as<VarBindingNode>()) {

tests/python/relax/test_analysis.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from typing import List, Set, Union
1919

20+
import pytest
2021
import tvm
2122
import tvm.testing
2223
from tvm import relax as rx
@@ -26,6 +27,7 @@
2627
all_vars,
2728
bound_vars,
2829
free_vars,
30+
used_vars,
2931
has_reshape_pattern,
3032
name_to_binding,
3133
remove_all_unused,
@@ -61,6 +63,27 @@ def test_use_def():
6163
assert set(udc[gv0]) == set()
6264

6365

66+
@pytest.mark.parametrize(
67+
"expr_fn, expected_var_names",
68+
[
69+
(lambda x, y, z: rx.op.add(x, y), {"x", "y"}),
70+
(lambda x, y, z: rx.op.multiply(x, x), {"x"}),
71+
(lambda x, y, z: rx.Tuple([x, y, z]), {"x", "y", "z"}),
72+
],
73+
ids=["binary_op", "self_reference", "tuple"],
74+
)
75+
def test_used_vars(expr_fn, expected_var_names):
76+
m = tir.Var("m", "int64")
77+
n = tir.Var("n", "int64")
78+
x = rx.Var("x", R.Tensor([m, n], "float16"))
79+
y = rx.Var("y", R.Tensor([n], "float16"))
80+
z = rx.Var("z", R.Tensor([m], "float16"))
81+
82+
expr = expr_fn(x, y, z)
83+
result = used_vars(expr)
84+
assert var_name_set(result) == expected_var_names
85+
86+
6487
def test_chained_remove_all_unused():
6588
@tvm.script.ir_module
6689
class IdentityUnused:

0 commit comments

Comments
 (0)