|
18 | 18 |
|
19 | 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | 20 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 21 | +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
21 | 22 | #include "mlir/IR/Builders.h" |
22 | 23 | #include "mlir/IR/BuiltinOps.h" |
23 | 24 | #include "mlir/IR/Matchers.h" |
@@ -119,6 +120,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { |
119 | 120 | // check variable read/write data types against variable declarations |
120 | 121 | LogicalResult applyVariableCheck(Operation *op); |
121 | 122 |
|
| 123 | + // check error if conditions |
| 124 | + LogicalResult applyErrorIfCheck(Operation *op); |
| 125 | + |
122 | 126 | private: |
123 | 127 | void populateConstantOperandChecks() { |
124 | 128 | constCheckers.emplace_back(checkConstantOperandPad); |
@@ -383,11 +387,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { |
383 | 387 | // Resize op: level check max scales |
384 | 388 | bool levelCheckResize(Operation *op) { |
385 | 389 | if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { |
386 | | - auto scale = resize.getScale(); |
387 | | - int16_t scaleYN = scale[0]; |
388 | | - int16_t scaleYD = scale[1]; |
389 | | - int16_t scaleXN = scale[2]; |
390 | | - int16_t scaleXD = scale[3]; |
| 390 | + SmallVector<int64_t> scale; |
| 391 | + if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) { |
| 392 | + return false; |
| 393 | + } |
| 394 | + const int64_t scaleYN = scale[0]; |
| 395 | + const int64_t scaleYD = scale[1]; |
| 396 | + const int64_t scaleXN = scale[2]; |
| 397 | + const int64_t scaleXD = scale[3]; |
391 | 398 | if (!levelCheckScale(op, scaleYN / scaleYD, |
392 | 399 | "scale_y_n/scale_y_d <= MAX_SCALE") || |
393 | 400 | !levelCheckScale(op, scaleXN / scaleXD, |
@@ -519,6 +526,106 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) { |
519 | 526 | return success(); |
520 | 527 | } |
521 | 528 |
|
| 529 | +bool checkErrorIfResize(Operation *op) { |
| 530 | + if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { |
| 531 | + const Value input = resize.getInput(); |
| 532 | + const Value output = resize.getOutput(); |
| 533 | + const RankedTensorType inputType = |
| 534 | + llvm::dyn_cast<RankedTensorType>(input.getType()); |
| 535 | + const RankedTensorType outputType = |
| 536 | + llvm::dyn_cast<RankedTensorType>(output.getType()); |
| 537 | + |
| 538 | + if (!inputType || !outputType) { |
| 539 | + op->emitOpError("expect ranked input/output tensor"); |
| 540 | + return false; |
| 541 | + } |
| 542 | + |
| 543 | + // Ensure the image size is supported by GPU APIs and that for integer |
| 544 | + // implementations, position * stride does not overflow int32_t. |
| 545 | + if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
| 546 | + const SmallVector<int64_t, 4> sizes = { |
| 547 | + outputType.getDimSize(1), outputType.getDimSize(2), |
| 548 | + inputType.getDimSize(1), inputType.getDimSize(2)}; |
| 549 | + const int64_t *maxDim = llvm::max_element(sizes); |
| 550 | + if (maxDim != sizes.end() && *maxDim >= 16384) { |
| 551 | + op->emitOpError("expect input/output height/width dims to be < 16384, ") |
| 552 | + << "got [OH, OW, IH, IW] = " << sizes; |
| 553 | + return false; |
| 554 | + } |
| 555 | + } |
| 556 | + |
| 557 | + SmallVector<int64_t> scale; |
| 558 | + if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) { |
| 559 | + return false; |
| 560 | + } |
| 561 | + |
| 562 | + const int64_t scaleYN = scale[0]; |
| 563 | + const int64_t scaleYD = scale[1]; |
| 564 | + const int64_t scaleXN = scale[2]; |
| 565 | + const int64_t scaleXD = scale[3]; |
| 566 | + |
| 567 | + // Ensure scale values don't overflow int32 accumulator |
| 568 | + if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) { |
| 569 | + op->emitOpError("expect all scale numerator values to be <= (1 << 11), " |
| 570 | + "got scale_y_n=") |
| 571 | + << scaleYN << ", scale_x_n=" << scaleXN; |
| 572 | + return false; |
| 573 | + } |
| 574 | + |
| 575 | + if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) { |
| 576 | + op->emitOpError("expect a downscale ratio larger than 1/16, got y=") |
| 577 | + << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD; |
| 578 | + return false; |
| 579 | + } |
| 580 | + |
| 581 | + SmallVector<int64_t> offset; |
| 582 | + SmallVector<int64_t> border; |
| 583 | + if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) || |
| 584 | + !tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) { |
| 585 | + return false; |
| 586 | + } |
| 587 | + |
| 588 | + const int64_t offsetY = offset[0]; |
| 589 | + const int64_t offsetX = offset[1]; |
| 590 | + const int64_t borderY = border[0]; |
| 591 | + const int64_t borderX = border[1]; |
| 592 | + |
| 593 | + // Set a consistent lower limit of 1/16 downscale to simplify |
| 594 | + // implementations |
| 595 | + if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) { |
| 596 | + op->emitOpError( |
| 597 | + "expect offsetY / scaleYNumerator to be in range [-1, 16), got ") |
| 598 | + << offsetY << "/" << scaleYN; |
| 599 | + return false; |
| 600 | + } |
| 601 | + if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) { |
| 602 | + op->emitOpError( |
| 603 | + "expect offsetX / scaleXNumerator to be in range [-1, 16), got ") |
| 604 | + << offsetX << "/" << scaleXN; |
| 605 | + return false; |
| 606 | + } |
| 607 | + if (borderY < -16 * scaleYN || borderY >= scaleYN) { |
| 608 | + op->emitOpError( |
| 609 | + "expect borderY / scaleYNumerator to be in range [-16, 1), got ") |
| 610 | + << borderY << "/" << scaleYN; |
| 611 | + return false; |
| 612 | + } |
| 613 | + if (borderX < -16 * scaleXN || borderX >= scaleXN) { |
| 614 | + op->emitOpError( |
| 615 | + "expect borderX / scaleXNumerator to be in range [-16, 1), got ") |
| 616 | + << borderX << "/" << scaleXN; |
| 617 | + return false; |
| 618 | + } |
| 619 | + } |
| 620 | + return true; |
| 621 | +} |
| 622 | + |
| 623 | +LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { |
| 624 | + if (!checkErrorIfResize(op)) |
| 625 | + return failure(); |
| 626 | + return success(); |
| 627 | +} |
| 628 | + |
522 | 629 | bool TosaValidation::isValidElementType(Type type) { |
523 | 630 | if (isa<FloatType>(type)) { |
524 | 631 | if (!isEnabledProfile(TosaProfileEnum::MainInference)) |
@@ -582,6 +689,10 @@ void TosaValidation::runOnOperation() { |
582 | 689 | // do variable type checks |
583 | 690 | if (failed(applyVariableCheck(op))) |
584 | 691 | signalPassFailure(); |
| 692 | + |
| 693 | + // do error if checks |
| 694 | + if (StrictOperationSpecAlignment && failed(applyErrorIfCheck(op))) |
| 695 | + signalPassFailure(); |
585 | 696 | }); |
586 | 697 | } |
587 | 698 | } // namespace |
0 commit comments