@@ -61,17 +61,22 @@ struct TosaLevel {
6161 int32_t MAX_KERNEL = 0 ;
6262 int32_t MAX_STRIDE = 0 ;
6363 int32_t MAX_SCALE = 0 ;
64-
65- // @todo: MAX_LOG2_SIZE value and checks
64+ int32_t MAX_LOG2_SIZE = 0 ;
65+ int32_t MAX_NESTING = 0 ;
66+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
6667
6768 bool operator ==(const TosaLevel &rhs) {
6869 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
70+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72+ MAX_NESTING == rhs.MAX_NESTING &&
73+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
7074 }
7175};
7276
73- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
74- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
77+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
78+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
79+ 63 , 256 , 256 };
7580
7681// ===----------------------------------------------------------------------===//
7782// TOSA Validation Pass.
@@ -111,133 +116,212 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
111116 constCheckers.emplace_back (checkConstantOperandPad);
112117 }
113118
114- bool levelCheckKernel (Operation *op, int32_t v,
115- const std::string &checkDesc) {
119+ bool levelCheckKernel (Operation *op, int32_t v, const StringRef checkDesc) {
116120 if (v > tosaLevel.MAX_KERNEL ) {
117121 op->emitOpError () << " failed level check: " << checkDesc;
118122 return false ;
119123 }
120124 return true ;
121125 }
122126
123- bool levelCheckStride (Operation *op, int32_t v,
124- const std::string &checkDesc) {
127+ bool levelCheckStride (Operation *op, int32_t v, const StringRef checkDesc) {
125128 if (v > tosaLevel.MAX_STRIDE ) {
126129 op->emitOpError () << " failed level check: " << checkDesc;
127130 return false ;
128131 }
129132 return true ;
130133 }
131134
132- bool levelCheckScale (Operation *op, int32_t v, const std::string & checkDesc) {
135+ bool levelCheckScale (Operation *op, int32_t v, const StringRef checkDesc) {
133136 if (v > tosaLevel.MAX_SCALE ) {
134137 op->emitOpError () << " failed level check: " << checkDesc;
135138 return false ;
136139 }
137140 return true ;
138141 }
139142
140- bool levelCheckRank (Operation *op, const Value &v,
141- const std::string &checkDesc) {
143+ bool levelCheckListSize (Operation *op, int32_t v, const StringRef checkDesc) {
144+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
145+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
146+ << checkDesc;
147+ return false ;
148+ }
149+ return true ;
150+ }
151+
152+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
153+ const StringRef operandOrResult,
154+ int32_t highest_rank) {
142155 if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
143156 if (!type.hasRank ()) {
144157 op->emitOpError () << " failed level check: unranked tensor" ;
145158 return false ;
146159 }
147- if (type.getRank () > tosaLevel.MAX_RANK ) {
148- op->emitOpError () << " failed level check: " << checkDesc;
160+ if (type.getRank () > highest_rank) {
161+ op->emitOpError () << " failed level check: " << operandOrResult
162+ << " rank(shape) <= MAX_RANK" ;
163+ return false ;
164+ }
165+
166+ auto shape = type.getShape ();
167+ for (auto dim : shape) {
168+ if (mlir::ShapedType::isDynamic (dim)) {
169+ op->emitOpError () << " failed level check: " << operandOrResult
170+ << " shape dimension cannot be dynamic" ;
171+ return false ;
172+ }
173+ }
174+
175+ int64_t element_bits = type.getElementTypeBitWidth ();
176+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
177+ int64_t size = element_bytes * type.getNumElements ();
178+
179+ // According to 1.11. Tensor Definitions of Tosa spec, the value of
180+ // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
181+ // defined in 1.7. Levels.
182+ // For each tensor, the number of tensor elements multiplied by the
183+ // element size in bytes must be representable as a tensor_size_t.
184+ const int64_t max_size = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
185+ if (size > max_size) {
186+ op->emitOpError ()
187+ << " failed level check: " << operandOrResult
188+ << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)" ;
149189 return false ;
150190 }
151191 }
152192 return true ;
153193 }
154194
155195 template <typename T>
156- bool levelCheckRanksFor (Operation *op ) {
157- if (dyn_cast<T>(op)) {
158- // level check ranks of all operands and results
159- for (auto v : op->getOperands ()) {
160- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK" ))
161- return false ;
162- }
163- for ( auto v : op-> getResults ()) {
164- if (! levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
165- return false ;
166- }
196+ bool levelCheckRanksAndSizesFor (T tosaOp ) {
197+ // level check ranks of all operands and results
198+ auto op = tosaOp. getOperation ();
199+ for (auto v : op->getOperands ()) {
200+ if (!levelCheckRankAndSizes (op, v, " operand" , tosaLevel. MAX_RANK ))
201+ return false ;
202+ }
203+
204+ for ( auto v : op-> getResults ()) {
205+ if (! levelCheckRankAndSizes (op, v, " result " , tosaLevel. MAX_RANK ))
206+ return false ;
167207 }
168208 return true ;
169209 }
170210
171- bool levelCheckRanks (Operation *op) {
172- #define CHECK_RANKS_FOR (tosaOp ) \
173- if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174- return false ;
211+ template <>
212+ bool levelCheckRanksAndSizesFor (tosa::ArgMaxOp tosaOp) {
213+ auto op = tosaOp.getOperation ();
214+ if (!levelCheckRankAndSizes (op, tosaOp.getInput (), " operand" ,
215+ tosaLevel.MAX_RANK ))
216+ return false ;
217+
218+ // rank(output) = rank(input) - 1
219+ if (!levelCheckRankAndSizes (op, tosaOp.getOutput (), " result" ,
220+ tosaLevel.MAX_RANK - 1 ))
221+ return false ;
222+
223+ return true ;
224+ }
225+
226+ template <>
227+ bool levelCheckRanksAndSizesFor (tosa::IfOp tosaOp) {
228+ auto op = tosaOp.getOperation ();
229+
230+ // Only the condition input has rank limitation.
231+ if (!levelCheckRankAndSizes (op, tosaOp.getCond (), " operand" ,
232+ tosaLevel.MAX_RANK ))
233+ return false ;
234+
235+ return true ;
236+ }
237+
238+ bool levelCheckRanksAndSizes (Operation *op) {
239+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
240+ if (isa<tosa::tosaOp##Op>(op)) \
241+ if (!levelCheckRanksAndSizesFor (cast<tosa::tosaOp##Op>(op))) \
242+ return false ;
243+
244+ #define CHECK_RANKS_AND_SIZES_SKIP (tosaOp ) \
245+ if (isa<tosa::tosaOp##Op>(op)) \
246+ return true ;
175247
176248 // tensor operators:
177- CHECK_RANKS_FOR (ArgMax);
249+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
178250 // all activation functions:
179- CHECK_RANKS_FOR (Clamp);
180- CHECK_RANKS_FOR (Sigmoid);
181- CHECK_RANKS_FOR (Tanh);
251+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
252+ CHECK_RANKS_AND_SIZES_FOR (Erf);
253+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
254+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
182255 // all elementwise binary operators:
183- CHECK_RANKS_FOR (Add);
184- CHECK_RANKS_FOR (ArithmeticRightShift);
185- CHECK_RANKS_FOR (BitwiseAnd);
186- CHECK_RANKS_FOR (BitwiseOr);
187- CHECK_RANKS_FOR (BitwiseXor);
188- CHECK_RANKS_FOR (IntDiv);
189- CHECK_RANKS_FOR (LogicalAnd);
190- CHECK_RANKS_FOR (LogicalLeftShift);
191- CHECK_RANKS_FOR (LogicalRightShift);
192- CHECK_RANKS_FOR (LogicalOr);
193- CHECK_RANKS_FOR (LogicalXor);
194- CHECK_RANKS_FOR (Maximum);
195- CHECK_RANKS_FOR (Minimum);
196- CHECK_RANKS_FOR (Mul);
197- CHECK_RANKS_FOR (Pow);
198- CHECK_RANKS_FOR (Sub);
199- CHECK_RANKS_FOR (Table);
256+ CHECK_RANKS_AND_SIZES_FOR (Add);
257+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
258+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
259+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
260+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
261+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
262+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
263+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
264+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
265+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
266+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
267+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
268+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
269+ CHECK_RANKS_AND_SIZES_FOR (Mul);
270+ CHECK_RANKS_AND_SIZES_FOR (Pow);
271+ CHECK_RANKS_AND_SIZES_FOR (Sub);
272+ CHECK_RANKS_AND_SIZES_FOR (Table);
200273 // all elementwise unary operators:
201- CHECK_RANKS_FOR (Abs);
202- CHECK_RANKS_FOR (BitwiseNot);
203- CHECK_RANKS_FOR (Ceil);
204- CHECK_RANKS_FOR (Clz);
205- CHECK_RANKS_FOR (Exp);
206- CHECK_RANKS_FOR (Floor);
207- CHECK_RANKS_FOR (Log);
208- CHECK_RANKS_FOR (LogicalNot);
209- CHECK_RANKS_FOR (Negate);
210- CHECK_RANKS_FOR (Reciprocal);
211- CHECK_RANKS_FOR (Rsqrt);
274+ CHECK_RANKS_AND_SIZES_FOR (Abs);
275+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
276+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
277+ CHECK_RANKS_AND_SIZES_FOR (Clz);
278+ CHECK_RANKS_AND_SIZES_FOR (Cos);
279+ CHECK_RANKS_AND_SIZES_FOR (Exp);
280+ CHECK_RANKS_AND_SIZES_FOR (Floor);
281+ CHECK_RANKS_AND_SIZES_FOR (Log);
282+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
283+ CHECK_RANKS_AND_SIZES_FOR (Negate);
284+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
285+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
286+ CHECK_RANKS_AND_SIZES_FOR (Sin);
212287 // all elementwise ternary operators:
213- CHECK_RANKS_FOR (Select);
288+ CHECK_RANKS_AND_SIZES_FOR (Select);
214289 // all comparison operators:
215- CHECK_RANKS_FOR (Equal);
216- CHECK_RANKS_FOR (Greater);
217- CHECK_RANKS_FOR (GreaterEqual);
290+ CHECK_RANKS_AND_SIZES_FOR (Equal);
291+ CHECK_RANKS_AND_SIZES_FOR (Greater);
292+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
218293 // all reduction operators:
219- CHECK_RANKS_FOR (ReduceAll);
220- CHECK_RANKS_FOR (ReduceAny);
221- CHECK_RANKS_FOR (ReduceMax);
222- CHECK_RANKS_FOR (ReduceMin);
223- CHECK_RANKS_FOR (ReduceProduct);
224- CHECK_RANKS_FOR (ReduceSum);
294+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
295+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
296+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
297+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
298+ CHECK_RANKS_AND_SIZES_FOR (ReduceProduct);
299+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
225300 // all data layout operators:
226- CHECK_RANKS_FOR (Concat);
227- CHECK_RANKS_FOR (Pad);
228- CHECK_RANKS_FOR (Reshape);
229- CHECK_RANKS_FOR (Reverse);
230- CHECK_RANKS_FOR (Slice);
231- CHECK_RANKS_FOR (Tile);
232- CHECK_RANKS_FOR (Transpose);
301+ CHECK_RANKS_AND_SIZES_FOR (Concat);
302+ CHECK_RANKS_AND_SIZES_FOR (Pad);
303+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
304+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
305+ CHECK_RANKS_AND_SIZES_FOR (Slice);
306+ CHECK_RANKS_AND_SIZES_FOR (Tile);
307+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
233308 // all type conversion operators:
234- CHECK_RANKS_FOR (Cast);
235- CHECK_RANKS_FOR (Rescale);
309+ CHECK_RANKS_AND_SIZES_FOR (Cast);
310+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
311+ // control flow operators:
312+ CHECK_RANKS_AND_SIZES_FOR (If);
236313 // all data nodes operators:
237- CHECK_RANKS_FOR (Const);
238- CHECK_RANKS_FOR (Identity);
314+ CHECK_RANKS_AND_SIZES_FOR (Const);
315+ CHECK_RANKS_AND_SIZES_FOR (Identity);
239316
240- #undef CHECK_RANKS_FOR
317+ // The following operators do not have level rank and size constraint.
318+ CHECK_RANKS_AND_SIZES_SKIP (Resize);
319+ CHECK_RANKS_AND_SIZES_SKIP (Yield);
320+ CHECK_RANKS_AND_SIZES_SKIP (Custom);
321+ CHECK_RANKS_AND_SIZES_SKIP (While);
322+
323+ #undef CHECK_RANKS_AND_SIZES_FOR
324+ #undef CHECK_RANKS_AND_SIZES_SKIP
241325 return true ;
242326 }
243327
@@ -386,6 +470,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386470 return true ;
387471 }
388472
473+ bool levelCheckListSize (Operation *op) {
474+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
475+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
476+ }
477+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
478+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
479+ !levelCheckListSize (op, custom.getOutputList ().size (),
480+ " output_list" )) {
481+ return false ;
482+ }
483+ }
484+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
485+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
486+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
487+ return false ;
488+ }
489+ }
490+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
491+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
492+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
493+ return false ;
494+ }
495+ }
496+ return true ;
497+ }
498+
389499 // configure profile and level values from pass options profileName and
390500 // levelName
391501 void configLevelAndProfile () {
@@ -439,7 +549,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439549 return success ();
440550 }
441551
442- if (!levelCheckRanks (op)) {
552+ if (!levelCheckRanksAndSizes (op)) {
443553 return failure ();
444554 }
445555
@@ -455,6 +565,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455565 return failure ();
456566 }
457567
568+ // level check MAX_TENSOR_LIST_SIZE
569+ if (!levelCheckListSize (op)) {
570+ return failure ();
571+ }
572+
458573 return success ();
459574}
460575
0 commit comments