@@ -41,17 +41,91 @@ using namespace mlir::tosa;
4141
4242namespace {
4343
44- static LogicalResult checkConstantOperandPad (Operation *op) {
44+ static LogicalResult
45+ checkConstantOperands (Operation *op, ArrayRef<unsigned int > operandIndices) {
46+ for (const auto index : operandIndices) {
47+ Attribute attr;
48+ if (!matchPattern (op->getOperand (index), m_Constant (&attr))) {
49+ return op->emitOpError (" expected compile time resolvable constant, but "
50+ " got variable value for operand #" )
51+ << index;
52+ }
53+ }
54+ return success ();
55+ }
56+
57+ static LogicalResult checkConstantOperandMul (Operation *op,
58+ const TargetEnv &env) {
59+ if (!env.allows (Extension::dynamic) && isa<tosa::MulOp>(op)) {
60+ // Check 'shift'
61+ return checkConstantOperands (op, {2 });
62+ }
63+ return success ();
64+ }
65+
66+ static LogicalResult checkConstantOperandTable (Operation *op,
67+ const TargetEnv &env) {
68+ if (!env.allows (Extension::dynamic) && isa<tosa::TableOp>(op)) {
69+ // Check 'table'
70+ return checkConstantOperands (op, {1 });
71+ }
72+ return success ();
73+ }
74+
75+ static LogicalResult checkConstantOperandPad (Operation *op,
76+ const TargetEnv &env) {
4577 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
46- DenseElementsAttr paddings;
47- if (!matchPattern (padOp.getPadding (), m_Constant (&paddings)))
48- return op->emitOpError (" padding of pad is not constant" );
78+ // Assume this op is zero-padding if padConst is not presented
79+ if (!env.allows (Extension::dynamic) && padOp.getPadConst ())
80+ // Check 'pad_const'
81+ // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82+ return checkConstantOperands (op, {2 });
83+ }
84+ return success ();
85+ }
86+
87+ static LogicalResult checkConstantOperandRescale (Operation *op,
88+ const TargetEnv &env) {
89+ if (!env.allows (Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90+ // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91+ return checkConstantOperands (op, {1 , 2 , 3 , 4 });
92+ }
93+ return success ();
94+ }
95+
96+ template <typename T>
97+ static LogicalResult checkConstantOperandConvOps (Operation *op,
98+ const TargetEnv &env) {
99+ if (!env.allows (Extension::dynamic) && isa<T>(op)) {
100+ // Check 'input_zp' and 'weight_zp'
101+ return checkConstantOperands (op, {3 , 4 });
102+ }
103+ return success ();
104+ }
105+
106+ static LogicalResult checkConstantOperandMatMul (Operation *op,
107+ const TargetEnv &env) {
108+ if (!env.allows (Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109+ // Check 'A_zp' and 'B_zp'
110+ return checkConstantOperands (op, {2 , 3 });
111+ }
112+ return success ();
113+ }
114+
115+ static LogicalResult checkConstantOperandAvgPool2d (Operation *op,
116+ const TargetEnv &env) {
117+ if (!env.allows (Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118+ // Check 'input_zp' and 'output_zp'
119+ return checkConstantOperands (op, {1 , 2 });
120+ }
121+ return success ();
122+ }
49123
50- DenseElementsAttr padConst;
51- // Assume this op is zero-padding if padConst is not presented.
52- if (padOp. getPadConst ( ) &&
53- ! matchPattern (padOp. getPadConst (), m_Constant (&padConst)))
54- return op-> emitOpError ( " pad_const of pad is not constant " );
124+ static LogicalResult checkConstantOperandNegate (Operation *op,
125+ const TargetEnv &env) {
126+ if (!env. allows (Extension::dynamic ) && isa<tosa::NegateOp>(op)) {
127+ // Check 'input1_zp' and 'output_zp'
128+ return checkConstantOperands (op, { 1 , 2 } );
55129 }
56130 return success ();
57131}
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
97171
98172 LogicalResult applyConstantOperandCheck (Operation *op) {
99173 for (auto &checker : constCheckers) {
100- if (failed (checker (op)))
174+ if (failed (checker (op, targetEnv )))
101175 return failure ();
102176 }
103177 return success ();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
114188
115189private:
116190 void populateConstantOperandChecks () {
191+ constCheckers.emplace_back (checkConstantOperandMul);
192+ constCheckers.emplace_back (checkConstantOperandTable);
117193 constCheckers.emplace_back (checkConstantOperandPad);
194+ constCheckers.emplace_back (checkConstantOperandRescale);
195+ constCheckers.emplace_back (checkConstantOperandConvOps<tosa::Conv2DOp>);
196+ constCheckers.emplace_back (checkConstantOperandConvOps<tosa::Conv3DOp>);
197+ constCheckers.emplace_back (
198+ checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
199+ constCheckers.emplace_back (
200+ checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
201+ constCheckers.emplace_back (checkConstantOperandMatMul);
202+ constCheckers.emplace_back (checkConstantOperandAvgPool2d);
203+ constCheckers.emplace_back (checkConstantOperandNegate);
118204 }
119205
120206 bool levelCheckKernel (Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
436522 llvm::errs () << " unknown TOSA extension name passed in: " << ext
437523 << " , supported extension are int16, int4, bf16, "
438524 << " fp8e4m3, fp8e5m2, fft, variable, controlflow, "
439- << " doubleround and inexactround \n " ;
525+ << " doubleround, inexactround and dynamic \n " ;
440526 return signalPassFailure ();
441527 }
442528 }
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
447533 bool CheckVariableReadOrWrite (Operation *op);
448534 bool isValidElementType (Type type);
449535
450- SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
536+ SmallVector<
537+ std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
538+ constCheckers;
451539 TosaLevel tosaLevel;
452540 DenseMap<StringAttr, mlir::Type> variablesMap;
453541 TosaProfileCompliance profileComp;
0 commit comments