@@ -70,17 +70,22 @@ struct TosaLevel {
7070 int32_t MAX_KERNEL = 0 ;
7171 int32_t MAX_STRIDE = 0 ;
7272 int32_t MAX_SCALE = 0 ;
73-
74- // @todo: MAX_LOG2_SIZE value and checks
73+ int32_t MAX_LOG2_SIZE = 0 ;
74+ int32_t MAX_NESTING = 0 ;
75+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
7576
7677 bool operator ==(const TosaLevel &rhs) {
7778 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
78- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
79+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
80+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
81+ MAX_NESTING == rhs.MAX_NESTING &&
82+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
7983 }
8084};
8185
82- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
83- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
86+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
87+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
88+ 63 , 256 , 256 };
8489
8590// ===----------------------------------------------------------------------===//
8691// TOSA Validation Pass.
@@ -147,107 +152,151 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147152 return true ;
148153 }
149154
150- bool levelCheckRank (Operation *op, const Value &v,
151- const std::string &checkDesc) {
155+ bool levelCheckListSize (Operation *op, int32_t v,
156+ const std::string &checkDesc) {
157+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
158+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
159+ << checkDesc;
160+ return false ;
161+ }
162+ return true ;
163+ }
164+
165+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
166+ const std::string &operandOrResult) {
152167 if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
153168 if (!type.hasRank ()) {
154169 op->emitOpError () << " failed level check: unranked tensor" ;
155170 return false ;
156171 }
157172 if (type.getRank () > tosaLevel.MAX_RANK ) {
158- op->emitOpError () << " failed level check: " << checkDesc;
173+ op->emitOpError () << " failed level check: " << operandOrResult
174+ << " rank(shape) <= MAX_RANK" ;
159175 return false ;
160176 }
177+
178+ const int64_t max_dim = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
179+ const int64_t max_size =
180+ (INT64_C (1 ) << (tosaLevel.MAX_LOG2_SIZE + 1 )) - 1 ;
181+
182+ auto shape = type.getShape ();
183+ bool has_dynamic = false ;
184+ for (auto dim : shape) {
185+ if (mlir::ShapedType::isDynamic (dim)) {
186+ has_dynamic = true ;
187+ continue ;
188+ }
189+ if (dim > max_dim) {
190+ op->emitOpError () << " failed level check: " << operandOrResult
191+ << " shape dimension <= (1<<MAX_LOG2_SIZE) - 1" ;
192+ return false ;
193+ }
194+ }
195+ if (!has_dynamic) {
196+ int64_t element_bits = type.getElementTypeBitWidth ();
197+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
198+ int64_t size = element_bytes * type.getNumElements ();
199+ if (size > max_size) {
200+ op->emitOpError ()
201+ << " failed level check: " << operandOrResult << " tensor size "
202+ << size << " (in bytes) <= "
203+ << " (1<<MAX_LOG2_SIZE+1) - 1, where max_size = " << max_size;
204+ return false ;
205+ }
206+ }
161207 }
162208 return true ;
163209 }
164210
165211 template <typename T>
166- bool levelCheckRanksFor (Operation *op) {
212+ bool levelCheckRanksAndSizesFor (Operation *op) {
167213 if (dyn_cast<T>(op)) {
168214 // level check ranks of all operands and results
169215 for (auto v : op->getOperands ()) {
170- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK " ))
216+ if (!levelCheckRankAndSizes (op, v, " operand" ))
171217 return false ;
172218 }
173219 for (auto v : op->getResults ()) {
174- if (!levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
220+ if (!levelCheckRankAndSizes (op, v, " result" ))
175221 return false ;
176222 }
177223 }
178224 return true ;
179225 }
180226
181- bool levelCheckRanks (Operation *op) {
182- #define CHECK_RANKS_FOR (tosaOp ) \
183- if (!levelCheckRanksFor <tosaOp##Op>(op)) \
227+ bool levelCheckRanksAndSizes (Operation *op) {
228+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
229+ if (!levelCheckRanksAndSizesFor <tosaOp##Op>(op)) \
184230 return false ;
185231
186232 // tensor operators:
187- CHECK_RANKS_FOR (ArgMax);
233+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
188234 // all activation functions:
189- CHECK_RANKS_FOR (Clamp);
190- CHECK_RANKS_FOR (Sigmoid);
191- CHECK_RANKS_FOR (Tanh);
235+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
236+ CHECK_RANKS_AND_SIZES_FOR (Erf);
237+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
238+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
192239 // all elementwise binary operators:
193- CHECK_RANKS_FOR (Add);
194- CHECK_RANKS_FOR (ArithmeticRightShift);
195- CHECK_RANKS_FOR (BitwiseAnd);
196- CHECK_RANKS_FOR (BitwiseOr);
197- CHECK_RANKS_FOR (BitwiseXor);
198- CHECK_RANKS_FOR (IntDiv);
199- CHECK_RANKS_FOR (LogicalAnd);
200- CHECK_RANKS_FOR (LogicalLeftShift);
201- CHECK_RANKS_FOR (LogicalRightShift);
202- CHECK_RANKS_FOR (LogicalOr);
203- CHECK_RANKS_FOR (LogicalXor);
204- CHECK_RANKS_FOR (Maximum);
205- CHECK_RANKS_FOR (Minimum);
206- CHECK_RANKS_FOR (Mul);
207- CHECK_RANKS_FOR (Pow);
208- CHECK_RANKS_FOR (Sub);
209- CHECK_RANKS_FOR (Table);
240+ CHECK_RANKS_AND_SIZES_FOR (Add);
241+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
242+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
243+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
244+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
245+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
246+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
247+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
248+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
249+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
250+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
251+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
252+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
253+ CHECK_RANKS_AND_SIZES_FOR (Mul);
254+ CHECK_RANKS_AND_SIZES_FOR (Pow);
255+ CHECK_RANKS_AND_SIZES_FOR (Sub);
256+ CHECK_RANKS_AND_SIZES_FOR (Table);
210257 // all elementwise unary operators:
211- CHECK_RANKS_FOR (Abs);
212- CHECK_RANKS_FOR (BitwiseNot);
213- CHECK_RANKS_FOR (Ceil);
214- CHECK_RANKS_FOR (Clz);
215- CHECK_RANKS_FOR (Exp);
216- CHECK_RANKS_FOR (Floor);
217- CHECK_RANKS_FOR (Log);
218- CHECK_RANKS_FOR (LogicalNot);
219- CHECK_RANKS_FOR (Negate);
220- CHECK_RANKS_FOR (Reciprocal);
221- CHECK_RANKS_FOR (Rsqrt);
258+ CHECK_RANKS_AND_SIZES_FOR (Abs);
259+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
260+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
261+ CHECK_RANKS_AND_SIZES_FOR (Clz);
262+ CHECK_RANKS_AND_SIZES_FOR (Cos);
263+ CHECK_RANKS_AND_SIZES_FOR (Exp);
264+ CHECK_RANKS_AND_SIZES_FOR (Floor);
265+ CHECK_RANKS_AND_SIZES_FOR (Log);
266+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
267+ CHECK_RANKS_AND_SIZES_FOR (Negate);
268+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
269+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
270+ CHECK_RANKS_AND_SIZES_FOR (Sin);
222271 // all elementwise ternary operators:
223- CHECK_RANKS_FOR (Select);
272+ CHECK_RANKS_AND_SIZES_FOR (Select);
224273 // all comparison operators:
225- CHECK_RANKS_FOR (Equal);
226- CHECK_RANKS_FOR (Greater);
227- CHECK_RANKS_FOR (GreaterEqual);
274+ CHECK_RANKS_AND_SIZES_FOR (Equal);
275+ CHECK_RANKS_AND_SIZES_FOR (Greater);
276+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
228277 // all reduction operators:
229- CHECK_RANKS_FOR (ReduceAll);
230- CHECK_RANKS_FOR (ReduceAny);
231- CHECK_RANKS_FOR (ReduceMax);
232- CHECK_RANKS_FOR (ReduceMin);
233- CHECK_RANKS_FOR (ReduceProd);
234- CHECK_RANKS_FOR (ReduceSum);
278+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
279+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
280+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
281+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
282+ CHECK_RANKS_AND_SIZES_FOR (ReduceProd);
283+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
235284 // all data layout operators:
236- CHECK_RANKS_FOR (Concat);
237- CHECK_RANKS_FOR (Pad);
238- CHECK_RANKS_FOR (Reshape);
239- CHECK_RANKS_FOR (Reverse);
240- CHECK_RANKS_FOR (Slice);
241- CHECK_RANKS_FOR (Tile);
242- CHECK_RANKS_FOR (Transpose);
285+ CHECK_RANKS_AND_SIZES_FOR (Concat);
286+ CHECK_RANKS_AND_SIZES_FOR (Pad);
287+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
288+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
289+ CHECK_RANKS_AND_SIZES_FOR (Slice);
290+ CHECK_RANKS_AND_SIZES_FOR (Tile);
291+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
243292 // all type conversion operators:
244- CHECK_RANKS_FOR (Cast);
245- CHECK_RANKS_FOR (Rescale);
293+ CHECK_RANKS_AND_SIZES_FOR (Cast);
294+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
246295 // all data nodes operators:
247- CHECK_RANKS_FOR (Const);
248- CHECK_RANKS_FOR (Identity);
296+ CHECK_RANKS_AND_SIZES_FOR (Const);
297+ CHECK_RANKS_AND_SIZES_FOR (Identity);
249298
250- #undef CHECK_RANKS_FOR
299+ #undef CHECK_RANKS_AND_SIZES_FOR
251300 return true ;
252301 }
253302
@@ -396,6 +445,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
396445 return true ;
397446 }
398447
448+ bool levelCheckListSize (Operation *op) {
449+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
450+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
451+ }
452+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
453+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
454+ !levelCheckListSize (op, custom.getOutputList ().size (),
455+ " output_list" )) {
456+ return false ;
457+ }
458+ }
459+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
460+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
461+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
462+ return false ;
463+ }
464+ }
465+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
466+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
467+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
468+ return false ;
469+ }
470+ }
471+ return true ;
472+ }
473+
399474 // configure profile and level values from pass options profileName and
400475 // levelName
401476 void configLevelAndProfile () {
@@ -449,7 +524,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
449524 return success ();
450525 }
451526
452- if (!levelCheckRanks (op)) {
527+ if (!levelCheckRanksAndSizes (op)) {
453528 return failure ();
454529 }
455530
@@ -465,6 +540,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
465540 return failure ();
466541 }
467542
543+ // level check MAX_TENSOR_LIST_SIZE
544+ if (!levelCheckListSize (op)) {
545+ return failure ();
546+ }
547+
468548 return success ();
469549}
470550
@@ -695,6 +775,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
695775}
696776
697777bool TosaValidation::isValidElementType (Type type) {
778+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
779+ type = quantType.getStorageType ();
780+
698781 if (isa<FloatType>(type)) {
699782 return type.isF32 () || type.isF16 () || type.isBF16 ();
700783 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
0 commit comments