Skip to content

Commit 793356a

Browse files
lhutton1HugoSilvaSantos
authored andcommitted
[mlir][tosa] Improve level check error messages (#174980)
This commit adds expected and actual values to the level check error messages, making it easier for a user to diagnose issues.
1 parent e19738d commit 793356a

File tree

2 files changed

+113
-105
lines changed

2 files changed

+113
-105
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Pass/Pass.h"
2828
#include "mlir/Transforms/DialectConversion.h"
2929
#include "llvm/ADT/StringExtras.h"
30+
#include "llvm/Support/FormatVariadic.h"
3031

3132
namespace mlir {
3233
namespace tosa {
@@ -189,33 +190,40 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
189190
constCheckers.emplace_back(checkConstantOperandSilceShape);
190191
}
191192

192-
LogicalResult levelCheckKernel(Operation *op, int32_t v,
193-
const StringRef checkDesc) {
194-
if (v > targetEnv.getLevel().MAX_KERNEL)
195-
return op->emitOpError() << "failed level check: " << checkDesc;
193+
LogicalResult levelCheck(Operation *op, const int32_t calculatedValue,
194+
const int32_t maxLevel, const StringRef inputName,
195+
const StringRef levelName) {
196+
if (calculatedValue > maxLevel)
197+
return op->emitOpError()
198+
<< "failed level check: " << inputName << " <= " << levelName
199+
<< " (" << maxLevel << "), got " << calculatedValue;
196200
return success();
197201
}
198202

203+
LogicalResult levelCheckKernel(Operation *op, int32_t v,
204+
const StringRef inputName) {
205+
return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
206+
"MAX_KERNEL");
207+
}
208+
199209
LogicalResult levelCheckStride(Operation *op, int32_t v,
200-
const StringRef checkDesc) {
201-
if (v > targetEnv.getLevel().MAX_STRIDE)
202-
return op->emitOpError() << "failed level check: " << checkDesc;
203-
return success();
210+
const StringRef inputName) {
211+
return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
212+
"MAX_STRIDE");
204213
}
205214

206215
LogicalResult levelCheckScale(Operation *op, int32_t v,
207-
const StringRef checkDesc) {
208-
if (v > targetEnv.getLevel().MAX_SCALE)
209-
return op->emitOpError() << "failed level check: " << checkDesc;
210-
return success();
216+
const StringRef inputName) {
217+
return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
218+
"MAX_SCALE");
211219
}
212220

213221
LogicalResult levelCheckListSize(Operation *op, int32_t v,
214-
const StringRef checkDesc) {
215-
if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
216-
return op->emitOpError()
217-
<< "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
218-
return success();
222+
const StringRef inputName) {
223+
const std::string inputDesc =
224+
llvm::formatv("length(tensor_list_shape({0}))", inputName);
225+
return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
226+
inputDesc, "MAX_TENSOR_LIST_SIZE");
219227
}
220228

221229
// Perform the Level Rank check on the tensor type.
@@ -317,17 +325,17 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
317325
LogicalResult levelCheckPool(Operation *op) {
318326
if (auto poolOp = dyn_cast<T>(op)) {
319327
for (auto k : poolOp.getKernel()) {
320-
if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
328+
if (failed(levelCheckKernel(op, k, "kernel"))) {
321329
return failure();
322330
}
323331
}
324332
for (auto s : poolOp.getStride()) {
325-
if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
333+
if (failed(levelCheckStride(op, s, "stride"))) {
326334
return failure();
327335
}
328336
}
329337
for (auto p : poolOp.getPad()) {
330-
if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
338+
if (failed(levelCheckKernel(op, p, "pad"))) {
331339
return failure();
332340
}
333341
}
@@ -341,17 +349,17 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
341349
if (auto convOp = dyn_cast<T>(op)) {
342350

343351
for (auto k : convOp.getDilation()) {
344-
if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
352+
if (failed(levelCheckKernel(op, k, "dilation"))) {
345353
return failure();
346354
}
347355
}
348356
for (auto p : convOp.getPad()) {
349-
if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
357+
if (failed(levelCheckKernel(op, p, "pad"))) {
350358
return failure();
351359
}
352360
}
353361
for (auto s : convOp.getStride()) {
354-
if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
362+
if (failed(levelCheckStride(op, s, "stride"))) {
355363
return failure();
356364
}
357365
}
@@ -363,27 +371,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
363371
assert(shape.size() == 4);
364372
assert(dilation.size() == 2);
365373
if (failed(levelCheckKernel(op, dilation[0] * shape[1],
366-
"dilation_y * KH <= MAX_KERNEL)")) ||
374+
"dilation_y * KH")) ||
367375
failed(levelCheckKernel(op, dilation[1] * shape[2],
368-
"dilation_x * KW <= MAX_KERNEL)")))
376+
"dilation_x * KW")))
369377
return failure();
370378
} else if (isa<tosa::Conv3DOp>(op)) {
371379
assert(shape.size() == 5);
372380
assert(dilation.size() == 3);
373381
if (failed(levelCheckKernel(op, dilation[0] * shape[1],
374-
"dilation_d * KD <= MAX_KERNEL)")) ||
382+
"dilation_d * KD")) ||
375383
failed(levelCheckKernel(op, dilation[1] * shape[2],
376-
"dilation_y * KH <= MAX_KERNEL)")) ||
384+
"dilation_y * KH")) ||
377385
failed(levelCheckKernel(op, dilation[2] * shape[3],
378-
"dilation_x * KW <= MAX_KERNEL)")))
386+
"dilation_x * KW")))
379387
return failure();
380388
} else if (isa<tosa::DepthwiseConv2DOp>(op)) {
381389
assert(shape.size() == 4);
382390
assert(dilation.size() == 2);
383391
if (failed(levelCheckKernel(op, dilation[0] * shape[0],
384-
"dilation_y * KH <= MAX_KERNEL)")) ||
392+
"dilation_y * KH")) ||
385393
failed(levelCheckKernel(op, dilation[1] * shape[1],
386-
"dilation_x * KW <= MAX_KERNEL)")))
394+
"dilation_x * KW")))
387395
return failure();
388396
}
389397
}
@@ -445,8 +453,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
445453
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
446454
auto shape = type.getShape();
447455
assert(shape.size() == 3);
448-
if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
449-
failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
456+
if (failed(levelCheckKernel(op, shape[1], "H")) ||
457+
failed(levelCheckKernel(op, shape[2], "W"))) {
450458
return failure();
451459
}
452460
}
@@ -463,18 +471,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
463471
auto shape = filterType.getShape();
464472
assert(shape.size() == 4);
465473
// level check kernel sizes for kH and KW
466-
if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
467-
failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
474+
if (failed(levelCheckKernel(op, shape[1], "KH")) ||
475+
failed(levelCheckKernel(op, shape[2], "KW"))) {
468476
return failure();
469477
}
470478
}
471479
for (auto p : transpose.getOutPad()) {
472-
if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
480+
if (failed(levelCheckKernel(op, p, "pad"))) {
473481
return failure();
474482
}
475483
}
476484
for (auto s : transpose.getStride()) {
477-
if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
485+
if (failed(levelCheckStride(op, s, "stride"))) {
478486
return failure();
479487
}
480488
}
@@ -494,10 +502,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
494502
const int64_t scaleYD = scale[1];
495503
const int64_t scaleXN = scale[2];
496504
const int64_t scaleXD = scale[3];
497-
if (failed(levelCheckScale(op, scaleYN / scaleYD,
498-
"scale_y_n/scale_y_d <= MAX_SCALE")) ||
499-
failed(levelCheckScale(op, scaleXN / scaleXD,
500-
"scale_x_n/scale_x_d <= MAX_SCALE"))) {
505+
if (failed(
506+
levelCheckScale(op, scaleYN / scaleYD, "scale_y_n/scale_y_d")) ||
507+
failed(
508+
levelCheckScale(op, scaleXN / scaleXD, "scale_x_n/scale_x_d"))) {
501509
return failure();
502510
}
503511
}
@@ -524,11 +532,11 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
524532
int32_t maxNestedDepth = 0;
525533
getMaxNestedDepth(op, maxNestedDepth);
526534

527-
if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
528-
op->emitOpError() << "failed level check: " << maxNestedDepth
529-
<< " >= MAX_NESTING";
530-
return failure();
531-
}
535+
const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
536+
if (maxNestedDepth >= maxNestingLevel)
537+
return op->emitOpError()
538+
<< "failed level check: tosa_nesting_depth < MAX_NESTING" << " ("
539+
<< maxNestingLevel << "), got " << maxNestedDepth;
532540
return success();
533541
}
534542

0 commit comments

Comments
 (0)