@@ -404,14 +404,21 @@ class ConvertAtenReflectionPad2dOp
404
404
Value hDimSize = inputShape[hDim];
405
405
Value vDimSize = inputShape[vDim];
406
406
407
- assert (getHPadArgument (LEFT) < inputType.getShape ()[hDim] &&
408
- " Left padding too large" );
409
- assert (getHPadArgument (RIGHT) < inputType.getShape ()[hDim] &&
410
- " Right padding too large" );
411
- assert (getVPadArgument (TOP) < inputType.getShape ()[vDim] &&
412
- " Top padding too large" );
413
- assert (getVPadArgument (BOTTOM) < inputType.getShape ()[vDim] &&
414
- " Bottom padding too large" );
407
+ auto verifyPadding = [&](int64_t padArgument, int64_t dim,
408
+ StringRef errorMessage) {
409
+ auto padValue = rewriter.create <arith::ConstantIndexOp>(loc, padArgument);
410
+ Value index = rewriter.create <arith::ConstantIndexOp>(loc, dim);
411
+ Value shapeDim = rewriter.create <tensor::DimOp>(loc, input, index);
412
+ Value cmpPred = rewriter.create <arith::CmpIOp>(
413
+ loc, arith::CmpIPredicate::sle, padValue, shapeDim);
414
+ rewriter.create <cf::AssertOp>(loc, cmpPred,
415
+ rewriter.getStringAttr (errorMessage));
416
+ };
417
+
418
+ verifyPadding (getHPadArgument (LEFT), hDim, " Left padding too large" );
419
+ verifyPadding (getHPadArgument (RIGHT), hDim, " Right padding too large" );
420
+ verifyPadding (getVPadArgument (TOP), vDim, " Top padding too large" );
421
+ verifyPadding (getVPadArgument (BOTTOM), vDim, " Bottom padding too large" );
415
422
416
423
Type indexType = rewriter.getIndexType ();
417
424
Value zero = getConstant (rewriter, loc, 0 , indexType);
0 commit comments