@@ -61,17 +61,22 @@ struct TosaLevel {
6161 int32_t MAX_KERNEL = 0 ;
6262 int32_t MAX_STRIDE = 0 ;
6363 int32_t MAX_SCALE = 0 ;
64-
65- // @todo: MAX_LOG2_SIZE value and checks
64+ int32_t MAX_LOG2_SIZE = 0 ;
65+ int32_t MAX_NESTING = 0 ;
66+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
6667
6768 bool operator ==(const TosaLevel &rhs) {
6869 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
70+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72+ MAX_NESTING == rhs.MAX_NESTING &&
73+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
7074 }
7175};
7276
73- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
74- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
77+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
78+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
79+ 63 , 256 , 256 };
7580
7681// ===----------------------------------------------------------------------===//
7782// TOSA Validation Pass.
@@ -137,107 +142,188 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137142 return true ;
138143 }
139144
140- bool levelCheckRank (Operation *op, const Value &v,
141- const std::string &checkDesc) {
145+ bool levelCheckListSize (Operation *op, int32_t v,
146+ const std::string &checkDesc) {
147+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
148+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
149+ << checkDesc;
150+ return false ;
151+ }
152+ return true ;
153+ }
154+
155+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
156+ const std::string &operandOrResult,
157+ int32_t highest_rank) {
142158 if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
143159 if (!type.hasRank ()) {
144160 op->emitOpError () << " failed level check: unranked tensor" ;
145161 return false ;
146162 }
147- if (type.getRank () > tosaLevel.MAX_RANK ) {
148- op->emitOpError () << " failed level check: " << checkDesc;
163+ if (type.getRank () > highest_rank) {
164+ op->emitOpError () << " failed level check: " << operandOrResult
165+ << " rank(shape) <= MAX_RANK" ;
166+ return false ;
167+ }
168+
169+ auto shape = type.getShape ();
170+ for (auto dim : shape) {
171+ if (mlir::ShapedType::isDynamic (dim)) {
172+ op->emitOpError () << " failed level check: " << operandOrResult
173+ << " shape dimension cannot be dynamic" ;
174+ return false ;
175+ }
176+ }
177+
178+ int64_t element_bits = type.getElementTypeBitWidth ();
179+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
180+ int64_t size = element_bytes * type.getNumElements ();
181+
182+ // According to 1.11. Tensor Definitions of Tosa spec, the value of
183+ // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
184+ // defined in 1.7. Levels.
185+ // For each tensor, the number of tensor elements multiplied by the
186+ // element size in bytes must be representable as a tensor_size_t.
187+ const int64_t max_size = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
188+ if (size > max_size) {
189+ op->emitOpError ()
190+ << " failed level check: " << operandOrResult
191+ << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)" ;
149192 return false ;
150193 }
151194 }
152195 return true ;
153196 }
154197
155198 template <typename T>
156- bool levelCheckRanksFor (Operation *op ) {
157- if (dyn_cast<T>(op)) {
158- // level check ranks of all operands and results
159- for (auto v : op->getOperands ()) {
160- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK" ))
161- return false ;
162- }
163- for ( auto v : op-> getResults ()) {
164- if (! levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
165- return false ;
166- }
199+ bool levelCheckRanksAndSizesFor (T tosaOp ) {
200+ // level check ranks of all operands and results
201+ auto op = tosaOp. getOperation ();
202+ for (auto v : op->getOperands ()) {
203+ if (!levelCheckRankAndSizes (op, v, " operand" , tosaLevel. MAX_RANK ))
204+ return false ;
205+ }
206+
207+ for ( auto v : op-> getResults ()) {
208+ if (! levelCheckRankAndSizes (op, v, " result " , tosaLevel. MAX_RANK ))
209+ return false ;
167210 }
168211 return true ;
169212 }
170213
171- bool levelCheckRanks (Operation *op) {
172- #define CHECK_RANKS_FOR (tosaOp ) \
173- if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174- return false ;
214+ template <>
215+ bool levelCheckRanksAndSizesFor (tosa::ArgMaxOp tosaOp) {
216+ auto op = tosaOp.getOperation ();
217+ if (!levelCheckRankAndSizes (op, tosaOp.getInput (), " operand" ,
218+ tosaLevel.MAX_RANK ))
219+ return false ;
220+
221+ // rank(output) = rank(input) - 1
222+ if (!levelCheckRankAndSizes (op, tosaOp.getOutput (), " result" ,
223+ tosaLevel.MAX_RANK - 1 ))
224+ return false ;
225+
226+ return true ;
227+ }
228+
229+ template <>
230+ bool levelCheckRanksAndSizesFor (tosa::IfOp tosaOp) {
231+ auto op = tosaOp.getOperation ();
232+
233+ // Only the condition input has rank limitation.
234+ if (!levelCheckRankAndSizes (op, tosaOp.getCond (), " operand" ,
235+ tosaLevel.MAX_RANK ))
236+ return false ;
237+
238+ return true ;
239+ }
240+
241+ bool levelCheckRanksAndSizes (Operation *op) {
242+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
243+ if (isa<tosa::tosaOp##Op>(op)) \
244+ if (!levelCheckRanksAndSizesFor (cast<tosa::tosaOp##Op>(op))) \
245+ return false ;
246+
247+ #define CHECK_RANKS_AND_SIZES_SKIP (tosaOp ) \
248+ if (isa<tosa::tosaOp##Op>(op)) \
249+ return true ;
175250
176251 // tensor operators:
177- CHECK_RANKS_FOR (ArgMax);
252+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
178253 // all activation functions:
179- CHECK_RANKS_FOR (Clamp);
180- CHECK_RANKS_FOR (Sigmoid);
181- CHECK_RANKS_FOR (Tanh);
254+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
255+ CHECK_RANKS_AND_SIZES_FOR (Erf);
256+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
257+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
182258 // all elementwise binary operators:
183- CHECK_RANKS_FOR (Add);
184- CHECK_RANKS_FOR (ArithmeticRightShift);
185- CHECK_RANKS_FOR (BitwiseAnd);
186- CHECK_RANKS_FOR (BitwiseOr);
187- CHECK_RANKS_FOR (BitwiseXor);
188- CHECK_RANKS_FOR (IntDiv);
189- CHECK_RANKS_FOR (LogicalAnd);
190- CHECK_RANKS_FOR (LogicalLeftShift);
191- CHECK_RANKS_FOR (LogicalRightShift);
192- CHECK_RANKS_FOR (LogicalOr);
193- CHECK_RANKS_FOR (LogicalXor);
194- CHECK_RANKS_FOR (Maximum);
195- CHECK_RANKS_FOR (Minimum);
196- CHECK_RANKS_FOR (Mul);
197- CHECK_RANKS_FOR (Pow);
198- CHECK_RANKS_FOR (Sub);
199- CHECK_RANKS_FOR (Table);
259+ CHECK_RANKS_AND_SIZES_FOR (Add);
260+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
261+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
262+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
263+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
264+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
265+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
266+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
267+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
268+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
269+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
270+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
271+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
272+ CHECK_RANKS_AND_SIZES_FOR (Mul);
273+ CHECK_RANKS_AND_SIZES_FOR (Pow);
274+ CHECK_RANKS_AND_SIZES_FOR (Sub);
275+ CHECK_RANKS_AND_SIZES_FOR (Table);
200276 // all elementwise unary operators:
201- CHECK_RANKS_FOR (Abs);
202- CHECK_RANKS_FOR (BitwiseNot);
203- CHECK_RANKS_FOR (Ceil);
204- CHECK_RANKS_FOR (Clz);
205- CHECK_RANKS_FOR (Exp);
206- CHECK_RANKS_FOR (Floor);
207- CHECK_RANKS_FOR (Log);
208- CHECK_RANKS_FOR (LogicalNot);
209- CHECK_RANKS_FOR (Negate);
210- CHECK_RANKS_FOR (Reciprocal);
211- CHECK_RANKS_FOR (Rsqrt);
277+ CHECK_RANKS_AND_SIZES_FOR (Abs);
278+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
279+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
280+ CHECK_RANKS_AND_SIZES_FOR (Clz);
281+ CHECK_RANKS_AND_SIZES_FOR (Cos);
282+ CHECK_RANKS_AND_SIZES_FOR (Exp);
283+ CHECK_RANKS_AND_SIZES_FOR (Floor);
284+ CHECK_RANKS_AND_SIZES_FOR (Log);
285+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
286+ CHECK_RANKS_AND_SIZES_FOR (Negate);
287+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
288+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
289+ CHECK_RANKS_AND_SIZES_FOR (Sin);
212290 // all elementwise ternary operators:
213- CHECK_RANKS_FOR (Select);
291+ CHECK_RANKS_AND_SIZES_FOR (Select);
214292 // all comparison operators:
215- CHECK_RANKS_FOR (Equal);
216- CHECK_RANKS_FOR (Greater);
217- CHECK_RANKS_FOR (GreaterEqual);
293+ CHECK_RANKS_AND_SIZES_FOR (Equal);
294+ CHECK_RANKS_AND_SIZES_FOR (Greater);
295+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
218296 // all reduction operators:
219- CHECK_RANKS_FOR (ReduceAll);
220- CHECK_RANKS_FOR (ReduceAny);
221- CHECK_RANKS_FOR (ReduceMax);
222- CHECK_RANKS_FOR (ReduceMin);
223- CHECK_RANKS_FOR (ReduceProduct);
224- CHECK_RANKS_FOR (ReduceSum);
297+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
298+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
299+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
300+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
301+ CHECK_RANKS_AND_SIZES_FOR (ReduceProduct);
302+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
225303 // all data layout operators:
226- CHECK_RANKS_FOR (Concat);
227- CHECK_RANKS_FOR (Pad);
228- CHECK_RANKS_FOR (Reshape);
229- CHECK_RANKS_FOR (Reverse);
230- CHECK_RANKS_FOR (Slice);
231- CHECK_RANKS_FOR (Tile);
232- CHECK_RANKS_FOR (Transpose);
304+ CHECK_RANKS_AND_SIZES_FOR (Concat);
305+ CHECK_RANKS_AND_SIZES_FOR (Pad);
306+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
307+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
308+ CHECK_RANKS_AND_SIZES_FOR (Slice);
309+ CHECK_RANKS_AND_SIZES_FOR (Tile);
310+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
233311 // all type conversion operators:
234- CHECK_RANKS_FOR (Cast);
235- CHECK_RANKS_FOR (Rescale);
312+ CHECK_RANKS_AND_SIZES_FOR (Cast);
313+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
314+ // control flow operators:
315+ CHECK_RANKS_AND_SIZES_FOR (If);
236316 // all data nodes operators:
237- CHECK_RANKS_FOR (Const);
238- CHECK_RANKS_FOR (Identity);
317+ CHECK_RANKS_AND_SIZES_FOR (Const);
318+ CHECK_RANKS_AND_SIZES_FOR (Identity);
319+
320+ // The following operators do not have level rank and size constraint.
321+ CHECK_RANKS_AND_SIZES_SKIP (Yield);
322+ CHECK_RANKS_AND_SIZES_SKIP (Custom);
323+ CHECK_RANKS_AND_SIZES_SKIP (While);
239324
240- #undef CHECK_RANKS_FOR
325+ #undef CHECK_RANKS_AND_SIZES_FOR
326+ #undef CHECK_RANKS_AND_SIZES_SKIP
241327 return true ;
242328 }
243329
@@ -386,6 +472,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386472 return true ;
387473 }
388474
475+ bool levelCheckListSize (Operation *op) {
476+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
477+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
478+ }
479+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
480+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
481+ !levelCheckListSize (op, custom.getOutputList ().size (),
482+ " output_list" )) {
483+ return false ;
484+ }
485+ }
486+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
487+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
488+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
489+ return false ;
490+ }
491+ }
492+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
493+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
494+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
495+ return false ;
496+ }
497+ }
498+ return true ;
499+ }
500+
389501 // configure profile and level values from pass options profileName and
390502 // levelName
391503 void configLevelAndProfile () {
@@ -439,7 +551,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439551 return success ();
440552 }
441553
442- if (!levelCheckRanks (op)) {
554+ if (!levelCheckRanksAndSizes (op)) {
443555 return failure ();
444556 }
445557
@@ -455,6 +567,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455567 return failure ();
456568 }
457569
570+ // level check MAX_TENSOR_LIST_SIZE
571+ if (!levelCheckListSize (op)) {
572+ return failure ();
573+ }
574+
458575 return success ();
459576}
460577
@@ -685,6 +802,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
685802}
686803
687804bool TosaValidation::isValidElementType (Type type) {
805+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
806+ type = quantType.getStorageType ();
807+
688808 if (isa<FloatType>(type)) {
689809 return type.isF32 () || type.isF16 () || type.isBF16 ();
690810 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
0 commit comments