2727#include " mlir/Pass/Pass.h"
2828#include " mlir/Transforms/DialectConversion.h"
2929#include " llvm/ADT/StringExtras.h"
30+ #include " llvm/Support/FormatVariadic.h"
3031
3132namespace mlir {
3233namespace tosa {
@@ -189,33 +190,40 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
189190 constCheckers.emplace_back (checkConstantOperandSilceShape);
190191 }
191192
192- LogicalResult levelCheckKernel (Operation *op, int32_t v,
193- const StringRef checkDesc) {
194- if (v > targetEnv.getLevel ().MAX_KERNEL )
195- return op->emitOpError () << " failed level check: " << checkDesc;
193+ LogicalResult levelCheck (Operation *op, const int32_t calculatedValue,
194+ const int32_t maxLevel, const StringRef inputName,
195+ const StringRef levelName) {
196+ if (calculatedValue > maxLevel)
197+ return op->emitOpError ()
198+ << " failed level check: " << inputName << " <= " << levelName
199+ << " (" << maxLevel << " ), got " << calculatedValue;
196200 return success ();
197201 }
198202
203+ LogicalResult levelCheckKernel (Operation *op, int32_t v,
204+ const StringRef inputName) {
205+ return levelCheck (op, v, targetEnv.getLevel ().MAX_KERNEL , inputName,
206+ " MAX_KERNEL" );
207+ }
208+
199209 LogicalResult levelCheckStride (Operation *op, int32_t v,
200- const StringRef checkDesc) {
201- if (v > targetEnv.getLevel ().MAX_STRIDE )
202- return op->emitOpError () << " failed level check: " << checkDesc;
203- return success ();
210+ const StringRef inputName) {
211+ return levelCheck (op, v, targetEnv.getLevel ().MAX_STRIDE , inputName,
212+ " MAX_STRIDE" );
204213 }
205214
206215 LogicalResult levelCheckScale (Operation *op, int32_t v,
207- const StringRef checkDesc) {
208- if (v > targetEnv.getLevel ().MAX_SCALE )
209- return op->emitOpError () << " failed level check: " << checkDesc;
210- return success ();
216+ const StringRef inputName) {
217+ return levelCheck (op, v, targetEnv.getLevel ().MAX_SCALE , inputName,
218+ " MAX_SCALE" );
211219 }
212220
213221 LogicalResult levelCheckListSize (Operation *op, int32_t v,
214- const StringRef checkDesc ) {
215- if (v > targetEnv. getLevel (). MAX_TENSOR_LIST_SIZE )
216- return op-> emitOpError ()
217- << " failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
218- return success ( );
222+ const StringRef inputName ) {
223+ const std::string inputDesc =
224+ llvm::formatv ( " length(tensor_list_shape({0})) " , inputName);
225+ return levelCheck (op, v, targetEnv. getLevel (). MAX_TENSOR_LIST_SIZE ,
226+ inputDesc, " MAX_TENSOR_LIST_SIZE " );
219227 }
220228
221229 // Perform the Level Rank check on the tensor type.
@@ -317,17 +325,17 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
317325 LogicalResult levelCheckPool (Operation *op) {
318326 if (auto poolOp = dyn_cast<T>(op)) {
319327 for (auto k : poolOp.getKernel ()) {
320- if (failed (levelCheckKernel (op, k, " kernel <= MAX_KERNEL " ))) {
328+ if (failed (levelCheckKernel (op, k, " kernel" ))) {
321329 return failure ();
322330 }
323331 }
324332 for (auto s : poolOp.getStride ()) {
325- if (failed (levelCheckStride (op, s, " stride <= MAX_STRIDE " ))) {
333+ if (failed (levelCheckStride (op, s, " stride" ))) {
326334 return failure ();
327335 }
328336 }
329337 for (auto p : poolOp.getPad ()) {
330- if (failed (levelCheckKernel (op, p, " pad <= MAX_KERNEL " ))) {
338+ if (failed (levelCheckKernel (op, p, " pad" ))) {
331339 return failure ();
332340 }
333341 }
@@ -341,17 +349,17 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
341349 if (auto convOp = dyn_cast<T>(op)) {
342350
343351 for (auto k : convOp.getDilation ()) {
344- if (failed (levelCheckKernel (op, k, " dilation <= MAX_KERNEL " ))) {
352+ if (failed (levelCheckKernel (op, k, " dilation" ))) {
345353 return failure ();
346354 }
347355 }
348356 for (auto p : convOp.getPad ()) {
349- if (failed (levelCheckKernel (op, p, " pad <= MAX_KERNEL " ))) {
357+ if (failed (levelCheckKernel (op, p, " pad" ))) {
350358 return failure ();
351359 }
352360 }
353361 for (auto s : convOp.getStride ()) {
354- if (failed (levelCheckStride (op, s, " stride <= MAX_STRIDE " ))) {
362+ if (failed (levelCheckStride (op, s, " stride" ))) {
355363 return failure ();
356364 }
357365 }
@@ -363,27 +371,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
363371 assert (shape.size () == 4 );
364372 assert (dilation.size () == 2 );
365373 if (failed (levelCheckKernel (op, dilation[0 ] * shape[1 ],
366- " dilation_y * KH <= MAX_KERNEL) " )) ||
374+ " dilation_y * KH" )) ||
367375 failed (levelCheckKernel (op, dilation[1 ] * shape[2 ],
368- " dilation_x * KW <= MAX_KERNEL) " )))
376+ " dilation_x * KW" )))
369377 return failure ();
370378 } else if (isa<tosa::Conv3DOp>(op)) {
371379 assert (shape.size () == 5 );
372380 assert (dilation.size () == 3 );
373381 if (failed (levelCheckKernel (op, dilation[0 ] * shape[1 ],
374- " dilation_d * KD <= MAX_KERNEL) " )) ||
382+ " dilation_d * KD" )) ||
375383 failed (levelCheckKernel (op, dilation[1 ] * shape[2 ],
376- " dilation_y * KH <= MAX_KERNEL) " )) ||
384+ " dilation_y * KH" )) ||
377385 failed (levelCheckKernel (op, dilation[2 ] * shape[3 ],
378- " dilation_x * KW <= MAX_KERNEL) " )))
386+ " dilation_x * KW" )))
379387 return failure ();
380388 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
381389 assert (shape.size () == 4 );
382390 assert (dilation.size () == 2 );
383391 if (failed (levelCheckKernel (op, dilation[0 ] * shape[0 ],
384- " dilation_y * KH <= MAX_KERNEL) " )) ||
392+ " dilation_y * KH" )) ||
385393 failed (levelCheckKernel (op, dilation[1 ] * shape[1 ],
386- " dilation_x * KW <= MAX_KERNEL) " )))
394+ " dilation_x * KW" )))
387395 return failure ();
388396 }
389397 }
@@ -445,8 +453,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
445453 if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
446454 auto shape = type.getShape ();
447455 assert (shape.size () == 3 );
448- if (failed (levelCheckKernel (op, shape[1 ], " H <= MAX_KERNEL " )) ||
449- failed (levelCheckKernel (op, shape[2 ], " W <= MAX_KERNEL " ))) {
456+ if (failed (levelCheckKernel (op, shape[1 ], " H" )) ||
457+ failed (levelCheckKernel (op, shape[2 ], " W" ))) {
450458 return failure ();
451459 }
452460 }
@@ -463,18 +471,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
463471 auto shape = filterType.getShape ();
464472 assert (shape.size () == 4 );
465473 // level check kernel sizes for kH and KW
466- if (failed (levelCheckKernel (op, shape[1 ], " KH <= MAX_KERNEL " )) ||
467- failed (levelCheckKernel (op, shape[2 ], " KW <= MAX_KERNEL " ))) {
474+ if (failed (levelCheckKernel (op, shape[1 ], " KH" )) ||
475+ failed (levelCheckKernel (op, shape[2 ], " KW" ))) {
468476 return failure ();
469477 }
470478 }
471479 for (auto p : transpose.getOutPad ()) {
472- if (failed (levelCheckKernel (op, p, " pad <= MAX_KERNEL " ))) {
480+ if (failed (levelCheckKernel (op, p, " pad" ))) {
473481 return failure ();
474482 }
475483 }
476484 for (auto s : transpose.getStride ()) {
477- if (failed (levelCheckStride (op, s, " stride <= MAX_STRIDE " ))) {
485+ if (failed (levelCheckStride (op, s, " stride" ))) {
478486 return failure ();
479487 }
480488 }
@@ -494,10 +502,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
494502 const int64_t scaleYD = scale[1 ];
495503 const int64_t scaleXN = scale[2 ];
496504 const int64_t scaleXD = scale[3 ];
497- if (failed (levelCheckScale (op, scaleYN / scaleYD,
498- " scale_y_n/scale_y_d <= MAX_SCALE " )) ||
499- failed (levelCheckScale (op, scaleXN / scaleXD,
500- " scale_x_n/scale_x_d <= MAX_SCALE " ))) {
505+ if (failed (
506+ levelCheckScale (op, scaleYN / scaleYD, " scale_y_n/scale_y_d" )) ||
507+ failed (
508+ levelCheckScale (op, scaleXN / scaleXD, " scale_x_n/scale_x_d" ))) {
501509 return failure ();
502510 }
503511 }
@@ -524,11 +532,11 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
524532 int32_t maxNestedDepth = 0 ;
525533 getMaxNestedDepth (op, maxNestedDepth);
526534
527- if (maxNestedDepth > = targetEnv.getLevel ().MAX_NESTING ) {
528- op-> emitOpError () << " failed level check: " << maxNestedDepth
529- << " >= MAX_NESTING " ;
530- return failure ();
531- }
535+ const int32_t maxNestingLevel = targetEnv.getLevel ().MAX_NESTING ;
536+ if (maxNestedDepth >= maxNestingLevel)
537+ return op-> emitOpError ()
538+ << " failed level check: tosa_nesting_depth < MAX_NESTING " << " ( "
539+ << maxNestingLevel << " ), got " << maxNestedDepth;
532540 return success ();
533541 }
534542
0 commit comments