|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +#include "symbolic_matcher.h" |
| 21 | + |
| 22 | +#include <tvm/arith/analyzer.h> |
| 23 | +#include <tvm/tir/op.h> |
| 24 | +#include <tvm/tir/stmt_functor.h> |
| 25 | + |
| 26 | +namespace tvm { |
| 27 | +namespace tir { |
| 28 | + |
| 29 | +void SymbolicMatcher::Match(const ffi::Array<PrimExpr>& params, const ffi::Array<PrimExpr>& args) { |
| 30 | + CHECK_EQ(params.size(), args.size()); |
| 31 | + for (size_t i = 0; i < params.size(); ++i) { |
| 32 | + Match(params[i], args[i]); |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +void SymbolicMatcher::Match(const PrimExpr& param, const PrimExpr& arg) { |
| 37 | + VisitExpr(param, arg); |
| 38 | + must_prove_ = analyzer_->Simplify(Substitute(must_prove_, *var_remap_)); |
| 39 | + CHECK(!is_zero(must_prove_)); |
| 40 | +} |
| 41 | + |
| 42 | +void SymbolicMatcher::VisitExpr(const PrimExpr& node, const PrimExpr& other) { |
| 43 | + if (node.same_as(other)) { |
| 44 | + return; |
| 45 | + } else if (node.dtype().code() != other.dtype().code()) { |
| 46 | + LOG(FATAL) << "Parameter expression " << node << " with dtype " << node.dtype() |
| 47 | + << " cannot match to argument " << other << " with dtype " << other.dtype(); |
| 48 | + } else { |
| 49 | + ExprFunctor::VisitExpr(node, other); |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \ |
| 54 | + void SymbolicMatcher::VisitExpr_(const OpName* op, const PrimExpr& other) { \ |
| 55 | + const auto* rhs = other.as<OpName>(); \ |
| 56 | + if (rhs) { \ |
| 57 | + VisitExpr(op->a, rhs->a); \ |
| 58 | + VisitExpr(op->b, rhs->b); \ |
| 59 | + } else { \ |
| 60 | + must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other); \ |
| 61 | + } \ |
| 62 | + } |
| 63 | + |
| 64 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode); |
| 65 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode); |
| 66 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode); |
| 67 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode); |
| 68 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode); |
| 69 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode); |
| 70 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode); |
| 71 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode); |
| 72 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode); |
| 73 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode); |
| 74 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode); |
| 75 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode); |
| 76 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode); |
| 77 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode); |
| 78 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode); |
| 79 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode); |
| 80 | +TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode); |
| 81 | + |
| 82 | +void SymbolicMatcher::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { |
| 83 | + const auto* rhs = other.as<IntImmNode>(); |
| 84 | + if (!rhs || (op->value != rhs->value)) { |
| 85 | + LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) |
| 86 | + << " expected an integer argument with value " << op->value << ", " |
| 87 | + << "but was provided with the argument " << other; |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +void SymbolicMatcher::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { |
| 92 | + const auto* rhs = other.as<FloatImmNode>(); |
| 93 | + if (!rhs || (op->value != rhs->value)) { |
| 94 | + LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) |
| 95 | + << " expected an float argument with value " << op->value << ", " |
| 96 | + << "but was provided with the argument " << other; |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +void SymbolicMatcher::VisitExpr_(const CastNode* op, const PrimExpr& other) { |
| 101 | + const auto* rhs = other.as<CastNode>(); |
| 102 | + if (!rhs) { |
| 103 | + LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " expected an cast to " |
| 104 | + << op->dtype << " as the argument, " |
| 105 | + << "but was provided with the argument " << other; |
| 106 | + } |
| 107 | + VisitExpr(op->value, rhs->value); |
| 108 | +} |
| 109 | + |
| 110 | +void SymbolicMatcher::VisitExpr_(const VarNode* op, const PrimExpr& rhs) { |
| 111 | + auto lhs = ffi::GetRef<Var>(op); |
| 112 | + |
| 113 | + if (lhs.same_as(rhs)) { |
| 114 | + // Reference identity, no further checks needed. |
| 115 | + } else if (op->dtype.code() != rhs->dtype.code()) { |
| 116 | + LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " with dtype " |
| 117 | + << op->dtype << " cannot match to argument " << rhs << " with dtype " << rhs.dtype(); |
| 118 | + } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { |
| 119 | + VisitExpr((*it).second, rhs); |
| 120 | + } else { |
| 121 | + var_remap_->Set(lhs, rhs); |
| 122 | + } |
| 123 | +} |
| 124 | + |
| 125 | +void SymbolicMatcher::VisitExpr_(const SelectNode* op, const PrimExpr& other) { |
| 126 | + const auto* rhs = other.as<SelectNode>(); |
| 127 | + if (rhs) { |
| 128 | + VisitExpr(op->condition, rhs->condition); |
| 129 | + VisitExpr(op->true_value, rhs->true_value); |
| 130 | + VisitExpr(op->false_value, rhs->false_value); |
| 131 | + } else { |
| 132 | + must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other); |
| 133 | + } |
| 134 | +} |
| 135 | + |
| 136 | +} // namespace tir |
| 137 | +} // namespace tvm |
0 commit comments