diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 549cd2197b4b..eb860b91e1c9 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -27,133 +27,12 @@ #include #include +#include "../../tir/analysis/symbolic_matcher.h" #include "../../tir/ir/functor_common.h" namespace tvm { namespace tir { -// TODO(Siyuan): move it to somewhere under tir folder -/*! - * \brief Match symbolic vars according to the given PrimExpr, and update the var_remap. - * Will throw errors if there is a mismatch. - */ -class SymbolicMatcher : ExprFunctor { - public: - explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) - : analyzer_(analyzer), var_remap_(var_remap) {} - - void Match(const ffi::Array& params, const ffi::Array& args) { - CHECK_EQ(params.size(), args.size()); - for (size_t i = 0; i < params.size(); ++i) { - Match(params[i], args[i]); - } - } - void Match(const PrimExpr& param, const PrimExpr& arg) { - VisitExpr(param, arg); - must_prove_ = analyzer_->Simplify(Substitute(must_prove_, *var_remap_)); - CHECK(!is_zero(must_prove_)); - } - - private: - void VisitExpr(const PrimExpr& node, const PrimExpr& other) { - if (node.same_as(other)) { - return; - } else if (node.dtype().code() != other.dtype().code()) { - LOG(FATAL) << "Parameter expression " << node << " with dtype " << node.dtype() - << " cannot match to argument " << other << " with dtype " << other.dtype(); - } else { - ExprFunctor::VisitExpr(node, other); - } - } - -#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \ - void VisitExpr_(const OpName* op, const PrimExpr& other) { \ - const auto* rhs = other.as(); \ - if (rhs) { \ - VisitExpr(op->a, rhs->a); \ - VisitExpr(op->b, rhs->b); \ - } else { \ - must_prove_ = must_prove_ && (ffi::GetRef(op) == other); \ - } \ - } - - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode); - TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode); - - void VisitExpr_(const IntImmNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) - << " expected an integer argument with value " << op->value << ", " - << "but was provided with the argument " << other; - } - } - - void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) - << " expected an float argument with value " << op->value << ", " - << "but was provided with the argument " << other; - } - } - - void VisitExpr_(const CastNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - if (!rhs) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " - << op->dtype << " as the argument, " - << "but was provided with the argument " << other; - } - VisitExpr(op->value, rhs->value); - } - - void VisitExpr_(const VarNode* op, const PrimExpr& rhs) { - auto lhs = ffi::GetRef(op); - - if (lhs.same_as(rhs)) { - // Reference identity, no further checks needed. - } else if (op->dtype.code() != rhs->dtype.code()) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " with dtype " - << op->dtype << " cannot match to argument " << rhs << " with dtype " - << rhs.dtype(); - } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { - VisitExpr((*it).second, rhs); - } else { - var_remap_->Set(lhs, rhs); - } - } - - void VisitExpr_(const SelectNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - if (rhs) { - VisitExpr(op->true_value, rhs->true_value); - VisitExpr(op->false_value, rhs->false_value); - } else { - must_prove_ = must_prove_ && (ffi::GetRef(op) == other); - } - } - - arith::Analyzer* analyzer_; - ffi::Map* var_remap_; - PrimExpr must_prove_ = Bool(true); -}; - /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions. */ diff --git a/src/tir/analysis/symbolic_matcher.cc b/src/tir/analysis/symbolic_matcher.cc new file mode 100644 index 000000000000..39e92c0a6106 --- /dev/null +++ b/src/tir/analysis/symbolic_matcher.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "symbolic_matcher.h" + +#include +#include +#include + +namespace tvm { +namespace tir { + +void SymbolicMatcher::Match(const ffi::Array& params, const ffi::Array& args) { + CHECK_EQ(params.size(), args.size()); + for (size_t i = 0; i < params.size(); ++i) { + Match(params[i], args[i]); + } +} + +void SymbolicMatcher::Match(const PrimExpr& param, const PrimExpr& arg) { + VisitExpr(param, arg); + must_prove_ = analyzer_->Simplify(Substitute(must_prove_, *var_remap_)); + CHECK(!is_zero(must_prove_)); +} + +void SymbolicMatcher::VisitExpr(const PrimExpr& node, const PrimExpr& other) { + if (node.same_as(other)) { + return; + } else if (node.dtype().code() != other.dtype().code()) { + LOG(FATAL) << "Parameter expression " << node << " with dtype " << node.dtype() + << " cannot match to argument " << other << " with dtype " << other.dtype(); + } else { + ExprFunctor::VisitExpr(node, other); + } +} + +#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \ + void SymbolicMatcher::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + if (rhs) { \ + VisitExpr(op->a, rhs->a); \ + VisitExpr(op->b, rhs->b); \ + } else { \ + must_prove_ = must_prove_ && (ffi::GetRef(op) == other); \ + } \ + } + +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode); +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode); + +void SymbolicMatcher::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + if (!rhs || (op->value != rhs->value)) { + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) + << " expected an integer argument with value " << op->value << ", " + << "but was provided with the argument " << other; + } +} + +void SymbolicMatcher::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + if (!rhs || (op->value != rhs->value)) { + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) + << " expected an float argument with value " << op->value << ", " + << "but was provided with the argument " << other; + } +} + +void SymbolicMatcher::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + if (!rhs) { + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " + << op->dtype << " as the argument, " + << "but was provided with the argument " << other; + } + VisitExpr(op->value, rhs->value); +} + +void SymbolicMatcher::VisitExpr_(const VarNode* op, const PrimExpr& rhs) { + auto lhs = ffi::GetRef(op); + + if (lhs.same_as(rhs)) { + // Reference identity, no further checks needed. + } else if (op->dtype.code() != rhs->dtype.code()) { + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " with dtype " + << op->dtype << " cannot match to argument " << rhs << " with dtype " << rhs.dtype(); + } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { + VisitExpr((*it).second, rhs); + } else { + var_remap_->Set(lhs, rhs); + } +} + +void SymbolicMatcher::VisitExpr_(const SelectNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + if (rhs) { + VisitExpr(op->condition, rhs->condition); + VisitExpr(op->true_value, rhs->true_value); + VisitExpr(op->false_value, rhs->false_value); + } else { + must_prove_ = must_prove_ && (ffi::GetRef(op) == other); + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/symbolic_matcher.h b/src/tir/analysis/symbolic_matcher.h new file mode 100644 index 000000000000..7deae5abb715 --- /dev/null +++ b/src/tir/analysis/symbolic_matcher.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_ +#define TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_ + +#include +#include + +namespace tvm { + +namespace arith { +class Analyzer; +} + +namespace tir { + +/*! + * \brief Match symbolic vars according to the given PrimExpr, and update the var_remap. + * Will throw errors if there is a mismatch. + */ +class SymbolicMatcher : ExprFunctor { + public: + explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) + : analyzer_(analyzer), var_remap_(var_remap) {} + + void Match(const ffi::Array& params, const ffi::Array& args); + void Match(const PrimExpr& param, const PrimExpr& arg); + + private: + void VisitExpr(const PrimExpr& node, const PrimExpr& other); + + void VisitExpr_(const AddNode* op, const PrimExpr& other) final; + void VisitExpr_(const SubNode* op, const PrimExpr& other) final; + void VisitExpr_(const MulNode* op, const PrimExpr& other) final; + void VisitExpr_(const DivNode* op, const PrimExpr& other) final; + void VisitExpr_(const ModNode* op, const PrimExpr& other) final; + void VisitExpr_(const EQNode* op, const PrimExpr& other) final; + void VisitExpr_(const NENode* op, const PrimExpr& other) final; + void VisitExpr_(const LTNode* op, const PrimExpr& other) final; + void VisitExpr_(const LENode* op, const PrimExpr& other) final; + void VisitExpr_(const GTNode* op, const PrimExpr& other) final; + void VisitExpr_(const GENode* op, const PrimExpr& other) final; + void VisitExpr_(const AndNode* op, const PrimExpr& other) final; + void VisitExpr_(const OrNode* op, const PrimExpr& other) final; + void VisitExpr_(const MinNode* op, const PrimExpr& other) final; + void VisitExpr_(const MaxNode* op, const PrimExpr& other) final; + void VisitExpr_(const FloorDivNode* op, const PrimExpr& other) final; + void VisitExpr_(const FloorModNode* op, const PrimExpr& other) final; + + void VisitExpr_(const IntImmNode* op, const PrimExpr& other) final; + void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) final; + void VisitExpr_(const CastNode* op, const PrimExpr& other) final; + void VisitExpr_(const VarNode* op, const PrimExpr& rhs) final; + void VisitExpr_(const SelectNode* op, const PrimExpr& other) final; + + arith::Analyzer* analyzer_; + ffi::Map* var_remap_; + PrimExpr must_prove_ = Bool(true); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_SYMBOLIC_MATCHER_H_