Skip to content

Commit 32aff8c

Browse files
Add runtime asserts for aten.reflection.pad_2d (#4057)
Add runtime asserts to check padding constraints of aten.reflection.pad_2d for dynamic dims
1 parent 1aacb46 commit 32aff8c

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,21 @@ class ConvertAtenReflectionPad2dOp
404404
Value hDimSize = inputShape[hDim];
405405
Value vDimSize = inputShape[vDim];
406406

407-
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
408-
"Left padding too large");
409-
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
410-
"Right padding too large");
411-
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
412-
"Top padding too large");
413-
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
414-
"Bottom padding too large");
407+
auto verifyPadding = [&](int64_t padArgument, int64_t dim,
408+
StringRef errorMessage) {
409+
auto padValue = rewriter.create<arith::ConstantIndexOp>(loc, padArgument);
410+
Value index = rewriter.create<arith::ConstantIndexOp>(loc, dim);
411+
Value shapeDim = rewriter.create<tensor::DimOp>(loc, input, index);
412+
Value cmpPred = rewriter.create<arith::CmpIOp>(
413+
loc, arith::CmpIPredicate::sle, padValue, shapeDim);
414+
rewriter.create<cf::AssertOp>(loc, cmpPred,
415+
rewriter.getStringAttr(errorMessage));
416+
};
417+
418+
verifyPadding(getHPadArgument(LEFT), hDim, "Left padding too large");
419+
verifyPadding(getHPadArgument(RIGHT), hDim, "Right padding too large");
420+
verifyPadding(getVPadArgument(TOP), vDim, "Top padding too large");
421+
verifyPadding(getVPadArgument(BOTTOM), vDim, "Bottom padding too large");
415422

416423
Type indexType = rewriter.getIndexType();
417424
Value zero = getConstant(rewriter, loc, 0, indexType);

0 commit comments

Comments
 (0)