Skip to content

Commit cfed598

Browse files
committed
Simplify paddingMode lowering
Evaluate paddingMode at compile time
1 parent e9659f8 commit cfed598

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
163163
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));
164164

165165
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
166-
binder.getLoc(), rewriter.getType<Torch::IntType>(),
167-
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
168-
paddingModeInt));
166+
binder.getLoc(), paddingModeInt);
169167

170168
bool alignMode = align;
171169
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2593,22 +2593,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25932593
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
25942594
};
25952595

2596-
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2596+
auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
25972597
Value x, Value SizeSubOne) -> Value {
2598-
Value border = lambdaBorder(b, loc, x, SizeSubOne);
2599-
Value zeroInt =
2600-
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
2601-
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2602-
paddingMode, zeroInt);
2598+
// Border
2599+
if (paddingMode == 1) {
2600+
return lambdaBorder(b, loc, x, SizeSubOne);
2601+
}
26032602

2604-
return b.create<arith::SelectOp>(loc, isZero, x, border);
2603+
return x;
26052604
};
26062605

26072606
auto resultType = cast<RankedTensorType>(
26082607
getTypeConverter()->convertType(op.getResult().getType()));
26092608
Value alignCorners = adaptor.getAlignCorners();
26102609
Value interMode = adaptor.getInterpolationMode();
2611-
Value paddingMode = adaptor.getPaddingMode();
2610+
2611+
int64_t paddingModeInt;
2612+
matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt));
26122613

26132614
SmallVector<Value> dynamicSizes{};
26142615
if (resultType.isDynamicDim(0))
@@ -2642,9 +2643,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
26422643
Value unnorm1 =
26432644
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
26442645
Value result0 =
2645-
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
2646+
lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d);
26462647
Value result1 =
2647-
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
2648+
lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d);
26482649
Value checkLowerBound0 = b.create<arith::CmpFOp>(
26492650
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
26502651
Value checkLowerBound1 = b.create<arith::CmpFOp>(

0 commit comments

Comments
 (0)