Skip to content

Commit 42d0baf

Browse files
committed
Move SymbolicMatcher to tir/analysis
1 parent fa905d2 commit 42d0baf

File tree

3 files changed

+219
-122
lines changed

3 files changed

+219
-122
lines changed

src/relax/transform/fuse_tir.cc

Lines changed: 1 addition & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -27,133 +27,12 @@
2727
#include <unordered_map>
2828
#include <unordered_set>
2929

30+
#include "../../tir/analysis/symbolic_matcher.h"
3031
#include "../../tir/ir/functor_common.h"
3132

3233
namespace tvm {
3334
namespace tir {
3435

35-
// TODO(Siyuan): move it to somewhere under tir folder
36-
/*!
37-
* \brief Match symbolic vars according to the given PrimExpr, and update the var_remap.
38-
* Will throw errors if there is a mismatch.
39-
*/
40-
class SymbolicMatcher : ExprFunctor<void(const PrimExpr& n, const PrimExpr& other)> {
41-
public:
42-
explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map<tir::Var, PrimExpr>* var_remap)
43-
: analyzer_(analyzer), var_remap_(var_remap) {}
44-
45-
void Match(const ffi::Array<PrimExpr>& params, const ffi::Array<PrimExpr>& args) {
46-
CHECK_EQ(params.size(), args.size());
47-
for (size_t i = 0; i < params.size(); ++i) {
48-
Match(params[i], args[i]);
49-
}
50-
}
51-
void Match(const PrimExpr& param, const PrimExpr& arg) {
52-
VisitExpr(param, arg);
53-
must_prove_ = analyzer_->Simplify(Substitute(must_prove_, *var_remap_));
54-
CHECK(!is_zero(must_prove_));
55-
}
56-
57-
private:
58-
void VisitExpr(const PrimExpr& node, const PrimExpr& other) {
59-
if (node.same_as(other)) {
60-
return;
61-
} else if (node.dtype().code() != other.dtype().code()) {
62-
LOG(FATAL) << "Parameter expression " << node << " with dtype " << node.dtype()
63-
<< " cannot match to argument " << other << " with dtype " << other.dtype();
64-
} else {
65-
ExprFunctor::VisitExpr(node, other);
66-
}
67-
}
68-
69-
#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \
70-
void VisitExpr_(const OpName* op, const PrimExpr& other) { \
71-
const auto* rhs = other.as<OpName>(); \
72-
if (rhs) { \
73-
VisitExpr(op->a, rhs->a); \
74-
VisitExpr(op->b, rhs->b); \
75-
} else { \
76-
must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other); \
77-
} \
78-
}
79-
80-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
81-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode);
82-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode);
83-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode);
84-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode);
85-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode);
86-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode);
87-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode);
88-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode);
89-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode);
90-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode);
91-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode);
92-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode);
93-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode);
94-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode);
95-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode);
96-
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode);
97-
98-
void VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
99-
const auto* rhs = other.as<IntImmNode>();
100-
if (!rhs || (op->value != rhs->value)) {
101-
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
102-
<< " expected an integer argument with value " << op->value << ", "
103-
<< "but was provided with the argument " << other;
104-
}
105-
}
106-
107-
void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
108-
const auto* rhs = other.as<FloatImmNode>();
109-
if (!rhs || (op->value != rhs->value)) {
110-
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
111-
<< " expected an float argument with value " << op->value << ", "
112-
<< "but was provided with the argument " << other;
113-
}
114-
}
115-
116-
void VisitExpr_(const CastNode* op, const PrimExpr& other) {
117-
const auto* rhs = other.as<CastNode>();
118-
if (!rhs) {
119-
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " expected an cast to "
120-
<< op->dtype << " as the argument, "
121-
<< "but was provided with the argument " << other;
122-
}
123-
VisitExpr(op->value, rhs->value);
124-
}
125-
126-
void VisitExpr_(const VarNode* op, const PrimExpr& rhs) {
127-
auto lhs = ffi::GetRef<Var>(op);
128-
129-
if (lhs.same_as(rhs)) {
130-
// Reference identity, no further checks needed.
131-
} else if (op->dtype.code() != rhs->dtype.code()) {
132-
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " with dtype "
133-
<< op->dtype << " cannot match to argument " << rhs << " with dtype "
134-
<< rhs.dtype();
135-
} else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) {
136-
VisitExpr((*it).second, rhs);
137-
} else {
138-
var_remap_->Set(lhs, rhs);
139-
}
140-
}
141-
142-
void VisitExpr_(const SelectNode* op, const PrimExpr& other) {
143-
const auto* rhs = other.as<SelectNode>();
144-
if (rhs) {
145-
VisitExpr(op->true_value, rhs->true_value);
146-
VisitExpr(op->false_value, rhs->false_value);
147-
} else {
148-
must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other);
149-
}
150-
}
151-
152-
arith::Analyzer* analyzer_;
153-
ffi::Map<tir::Var, PrimExpr>* var_remap_;
154-
PrimExpr must_prove_ = Bool(true);
155-
};
156-
15736
/*!
15837
* \brief Substitute a given source buffer with a given target buffer in statements or expressions.
15938
*/
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "symbolic_matcher.h"
21+
22+
#include <tvm/arith/analyzer.h>
23+
#include <tvm/tir/op.h>
24+
#include <tvm/tir/stmt_functor.h>
25+
26+
namespace tvm {
27+
namespace tir {
28+
29+
void SymbolicMatcher::Match(const ffi::Array<PrimExpr>& params, const ffi::Array<PrimExpr>& args) {
30+
CHECK_EQ(params.size(), args.size());
31+
for (size_t i = 0; i < params.size(); ++i) {
32+
Match(params[i], args[i]);
33+
}
34+
}
35+
36+
void SymbolicMatcher::Match(const PrimExpr& param, const PrimExpr& arg) {
37+
VisitExpr(param, arg);
38+
must_prove_ = analyzer_->Simplify(Substitute(must_prove_, *var_remap_));
39+
CHECK(!is_zero(must_prove_));
40+
}
41+
42+
void SymbolicMatcher::VisitExpr(const PrimExpr& node, const PrimExpr& other) {
43+
if (node.same_as(other)) {
44+
return;
45+
} else if (node.dtype().code() != other.dtype().code()) {
46+
LOG(FATAL) << "Parameter expression " << node << " with dtype " << node.dtype()
47+
<< " cannot match to argument " << other << " with dtype " << other.dtype();
48+
} else {
49+
ExprFunctor::VisitExpr(node, other);
50+
}
51+
}
52+
53+
#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \
54+
void SymbolicMatcher::VisitExpr_(const OpName* op, const PrimExpr& other) { \
55+
const auto* rhs = other.as<OpName>(); \
56+
if (rhs) { \
57+
VisitExpr(op->a, rhs->a); \
58+
VisitExpr(op->b, rhs->b); \
59+
} else { \
60+
must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other); \
61+
} \
62+
}
63+
64+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
65+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode);
66+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode);
67+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode);
68+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode);
69+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode);
70+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode);
71+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode);
72+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode);
73+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode);
74+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode);
75+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode);
76+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode);
77+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode);
78+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode);
79+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode);
80+
TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode);
81+
82+
void SymbolicMatcher::VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
83+
const auto* rhs = other.as<IntImmNode>();
84+
if (!rhs || (op->value != rhs->value)) {
85+
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
86+
<< " expected an integer argument with value " << op->value << ", "
87+
<< "but was provided with the argument " << other;
88+
}
89+
}
90+
91+
void SymbolicMatcher::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
92+
const auto* rhs = other.as<FloatImmNode>();
93+
if (!rhs || (op->value != rhs->value)) {
94+
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
95+
<< " expected an float argument with value " << op->value << ", "
96+
<< "but was provided with the argument " << other;
97+
}
98+
}
99+
100+
void SymbolicMatcher::VisitExpr_(const CastNode* op, const PrimExpr& other) {
101+
const auto* rhs = other.as<CastNode>();
102+
if (!rhs) {
103+
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " expected an cast to "
104+
<< op->dtype << " as the argument, "
105+
<< "but was provided with the argument " << other;
106+
}
107+
VisitExpr(op->value, rhs->value);
108+
}
109+
110+
void SymbolicMatcher::VisitExpr_(const VarNode* op, const PrimExpr& rhs) {
111+
auto lhs = ffi::GetRef<Var>(op);
112+
113+
if (lhs.same_as(rhs)) {
114+
// Reference identity, no further checks needed.
115+
} else if (op->dtype.code() != rhs->dtype.code()) {
116+
LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " with dtype "
117+
<< op->dtype << " cannot match to argument " << rhs << " with dtype " << rhs.dtype();
118+
} else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) {
119+
VisitExpr((*it).second, rhs);
120+
} else {
121+
var_remap_->Set(lhs, rhs);
122+
}
123+
}
124+
125+
void SymbolicMatcher::VisitExpr_(const SelectNode* op, const PrimExpr& other) {
126+
const auto* rhs = other.as<SelectNode>();
127+
if (rhs) {
128+
VisitExpr(op->condition, rhs->condition);
129+
VisitExpr(op->true_value, rhs->true_value);
130+
VisitExpr(op->false_value, rhs->false_value);
131+
} else {
132+
must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other);
133+
}
134+
}
135+
136+
} // namespace tir
137+
} // namespace tvm
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_
21+
#define TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_
22+
23+
#include <tvm/tir/expr.h>
24+
#include <tvm/tir/expr_functor.h>
25+
26+
namespace tvm {
27+
28+
namespace arith {
29+
class Analyzer;
30+
}
31+
32+
namespace tir {
33+
34+
/*!
35+
* \brief Match symbolic vars according to the given PrimExpr, and update the var_remap.
36+
* Will throw errors if there is a mismatch.
37+
*/
38+
class SymbolicMatcher : ExprFunctor<void(const PrimExpr& n, const PrimExpr& other)> {
39+
public:
40+
explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map<tir::Var, PrimExpr>* var_remap)
41+
: analyzer_(analyzer), var_remap_(var_remap) {}
42+
43+
void Match(const ffi::Array<PrimExpr>& params, const ffi::Array<PrimExpr>& args);
44+
void Match(const PrimExpr& param, const PrimExpr& arg);
45+
46+
private:
47+
void VisitExpr(const PrimExpr& node, const PrimExpr& other);
48+
49+
void VisitExpr_(const AddNode* op, const PrimExpr& other) final;
50+
void VisitExpr_(const SubNode* op, const PrimExpr& other) final;
51+
void VisitExpr_(const MulNode* op, const PrimExpr& other) final;
52+
void VisitExpr_(const DivNode* op, const PrimExpr& other) final;
53+
void VisitExpr_(const ModNode* op, const PrimExpr& other) final;
54+
void VisitExpr_(const EQNode* op, const PrimExpr& other) final;
55+
void VisitExpr_(const NENode* op, const PrimExpr& other) final;
56+
void VisitExpr_(const LTNode* op, const PrimExpr& other) final;
57+
void VisitExpr_(const LENode* op, const PrimExpr& other) final;
58+
void VisitExpr_(const GTNode* op, const PrimExpr& other) final;
59+
void VisitExpr_(const GENode* op, const PrimExpr& other) final;
60+
void VisitExpr_(const AndNode* op, const PrimExpr& other) final;
61+
void VisitExpr_(const OrNode* op, const PrimExpr& other) final;
62+
void VisitExpr_(const MinNode* op, const PrimExpr& other) final;
63+
void VisitExpr_(const MaxNode* op, const PrimExpr& other) final;
64+
void VisitExpr_(const FloorDivNode* op, const PrimExpr& other) final;
65+
void VisitExpr_(const FloorModNode* op, const PrimExpr& other) final;
66+
67+
void VisitExpr_(const IntImmNode* op, const PrimExpr& other) final;
68+
void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) final;
69+
void VisitExpr_(const CastNode* op, const PrimExpr& other) final;
70+
void VisitExpr_(const VarNode* op, const PrimExpr& rhs) final;
71+
void VisitExpr_(const SelectNode* op, const PrimExpr& other) final;
72+
73+
arith::Analyzer* analyzer_;
74+
ffi::Map<tir::Var, PrimExpr>* var_remap_;
75+
PrimExpr must_prove_ = Bool(true);
76+
};
77+
78+
} // namespace tir
79+
} // namespace tvm
80+
81+
#endif // TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_

0 commit comments

Comments
 (0)