@@ -527,96 +527,99 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
527527}
528528
529529bool checkErrorIfResize (Operation *op) {
530- if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
531- const Value input = resize.getInput ();
532- const Value output = resize.getOutput ();
533- const RankedTensorType inputType =
534- llvm::dyn_cast<RankedTensorType>(input.getType ());
535- const RankedTensorType outputType =
536- llvm::dyn_cast<RankedTensorType>(output.getType ());
537-
538- if (!inputType || !outputType) {
539- op->emitOpError (" expect ranked input/output tensor" );
540- return false ;
541- }
530+ auto resize = dyn_cast<tosa::ResizeOp>(op);
531+ if (!resize)
532+ return true ;
542533
543- // Ensure the image size is supported by GPU APIs and that for integer
544- // implementations, position * stride does not overflow int32_t.
545- if (inputType.hasStaticShape () && outputType.hasStaticShape ()) {
546- const SmallVector<int64_t , 4 > sizes = {
547- outputType.getDimSize (1 ), outputType.getDimSize (2 ),
548- inputType.getDimSize (1 ), inputType.getDimSize (2 )};
549- const int64_t *maxDim = llvm::max_element (sizes);
550- if (maxDim != sizes.end () && *maxDim >= 16384 ) {
551- op->emitOpError (" expect input/output height/width dims to be < 16384, " )
552- << " got [OH, OW, IH, IW] = " << sizes;
553- return false ;
554- }
555- }
534+ const Value input = resize.getInput ();
535+ const Value output = resize.getOutput ();
536+ const RankedTensorType inputType =
537+ llvm::dyn_cast<RankedTensorType>(input.getType ());
538+ const RankedTensorType outputType =
539+ llvm::dyn_cast<RankedTensorType>(output.getType ());
556540
557- SmallVector<int64_t > scale;
558- if (!tosa::getConstShapeValue (resize.getScale ().getDefiningOp (), scale)) {
541+ if (!inputType || !outputType) {
542+ op->emitOpError (" expect ranked input/output tensor" );
543+ return false ;
544+ }
545+
546+ // Ensure the image size is supported by GPU APIs and that for integer
547+ // implementations, position * stride does not overflow int32_t.
548+ if (inputType.hasStaticShape () && outputType.hasStaticShape ()) {
549+ const SmallVector<int64_t , 4 > sizes = {
550+ outputType.getDimSize (1 ), outputType.getDimSize (2 ),
551+ inputType.getDimSize (1 ), inputType.getDimSize (2 )};
552+ const int64_t *maxDim = llvm::max_element (sizes);
553+ if (maxDim != sizes.end () && *maxDim >= 16384 ) {
554+ op->emitOpError (" expect input/output height/width dims to be < 16384, " )
555+ << " got [OH, OW, IH, IW] = " << sizes;
559556 return false ;
560557 }
558+ }
561559
562- const int64_t scaleYN = scale[ 0 ] ;
563- const int64_t scaleYD = scale[ 1 ];
564- const int64_t scaleXN = scale[ 2 ] ;
565- const int64_t scaleXD = scale[ 3 ];
560+ SmallVector< int64_t > scale;
561+ if (! tosa::getConstShapeValue (resize. getScale (). getDefiningOp (), scale)) {
562+ return false ;
563+ }
566564
567- // Ensure scale values don't overflow int32 accumulator
568- if (scaleYN > (1 << 11 ) || scaleXN > (1 << 11 )) {
569- op->emitOpError (" expect all scale numerator values to be <= (1 << 11), "
570- " got scale_y_n=" )
571- << scaleYN << " , scale_x_n=" << scaleXN;
572- return false ;
573- }
565+ const int64_t scaleYN = scale[0 ];
566+ const int64_t scaleYD = scale[1 ];
567+ const int64_t scaleXN = scale[2 ];
568+ const int64_t scaleXD = scale[3 ];
574569
575- if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
576- op->emitOpError (" expect a downscale ratio larger than 1/16, got y=" )
577- << scaleYN << " /" << scaleYD << " , x=" << scaleXN << " /" << scaleXD;
578- return false ;
579- }
570+ // Ensure scale values don't overflow int32 accumulator
571+ if (scaleYN > (1 << 11 ) || scaleXN > (1 << 11 )) {
572+ op->emitOpError (" expect all scale numerator values to be <= (1 << 11), "
573+ " got scale_y_n=" )
574+ << scaleYN << " , scale_x_n=" << scaleXN;
575+ return false ;
576+ }
580577
581- SmallVector<int64_t > offset;
582- SmallVector<int64_t > border;
583- if (!tosa::getConstShapeValue (resize.getOffset ().getDefiningOp (), offset) ||
584- !tosa::getConstShapeValue (resize.getBorder ().getDefiningOp (), border)) {
585- return false ;
586- }
578+ if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
579+ op->emitOpError (" expect a downscale ratio larger than 1/16, got y=" )
580+ << scaleYN << " /" << scaleYD << " , x=" << scaleXN << " /" << scaleXD;
581+ return false ;
582+ }
587583
588- const int64_t offsetY = offset[0 ];
589- const int64_t offsetX = offset[1 ];
590- const int64_t borderY = border[0 ];
591- const int64_t borderX = border[1 ];
592-
593- // Set a consistent lower limit of 1/16 downscale to simplify
594- // implementations
595- if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596- op->emitOpError (
597- " expect offsetY / scaleYNumerator to be in range [-1, 16), got " )
598- << offsetY << " /" << scaleYN;
599- return false ;
600- }
601- if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602- op->emitOpError (
603- " expect offsetX / scaleXNumerator to be in range [-1, 16), got " )
604- << offsetX << " /" << scaleXN;
605- return false ;
606- }
607- if (borderY < -16 * scaleYN || borderY >= scaleYN) {
608- op->emitOpError (
609- " expect borderY / scaleYNumerator to be in range [-16, 1), got " )
610- << borderY << " /" << scaleYN;
611- return false ;
612- }
613- if (borderX < -16 * scaleXN || borderX >= scaleXN) {
614- op->emitOpError (
615- " expect borderX / scaleXNumerator to be in range [-16, 1), got " )
616- << borderX << " /" << scaleXN;
617- return false ;
618- }
584+ SmallVector<int64_t > offset;
585+ SmallVector<int64_t > border;
586+ if (!tosa::getConstShapeValue (resize.getOffset ().getDefiningOp (), offset) ||
587+ !tosa::getConstShapeValue (resize.getBorder ().getDefiningOp (), border)) {
588+ return false ;
619589 }
590+
591+ const int64_t offsetY = offset[0 ];
592+ const int64_t offsetX = offset[1 ];
593+ // Set a consistent lower limit of 1/16 downscale to simplify
594+ // implementations
595+ if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596+ op->emitOpError (
597+ " expect offsetY / scaleYNumerator to be in range [-1, 16), got " )
598+ << offsetY << " /" << scaleYN;
599+ return false ;
600+ }
601+ if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602+ op->emitOpError (
603+ " expect offsetX / scaleXNumerator to be in range [-1, 16), got " )
604+ << offsetX << " /" << scaleXN;
605+ return false ;
606+ }
607+
608+ const int64_t borderY = border[0 ];
609+ const int64_t borderX = border[1 ];
610+ if (borderY < -16 * scaleYN || borderY >= scaleYN) {
611+ op->emitOpError (
612+ " expect borderY / scaleYNumerator to be in range [-16, 1), got " )
613+ << borderY << " /" << scaleYN;
614+ return false ;
615+ }
616+ if (borderX < -16 * scaleXN || borderX >= scaleXN) {
617+ op->emitOpError (
618+ " expect borderX / scaleXNumerator to be in range [-16, 1), got " )
619+ << borderX << " /" << scaleXN;
620+ return false ;
621+ }
622+
620623 return true ;
621624}
622625
0 commit comments