Skip to content

Commit 216a024

Browse files
committed
[TorchToLinalg][GridSample] Add support for border padding mode
1 parent 1259e8a commit 216a024

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2565,10 +2565,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25652565
return res;
25662566
};
25672567

2568+
auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2569+
Value SizeSubOne) -> Value {
2570+
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
2571+
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2572+
};
2573+
2574+
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2575+
Value x, Value SizeSubOne) -> Value {
2576+
Value border = lambdaBorder(b, loc, x, SizeSubOne);
2577+
Value zeroInt =
2578+
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
2579+
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2580+
paddingMode, zeroInt);
2581+
2582+
return b.create<arith::SelectOp>(loc, isZero, x, border);
2583+
};
2584+
25682585
auto resultType = cast<RankedTensorType>(
25692586
getTypeConverter()->convertType(op.getResult().getType()));
25702587
Value alignCorners = adaptor.getAlignCorners();
25712588
Value interMode = adaptor.getInterpolationMode();
2589+
Value paddingMode = adaptor.getPaddingMode();
2590+
25722591
SmallVector<Value> dynamicSizes{};
25732592
if (resultType.isDynamicDim(0))
25742593
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@@ -2596,10 +2615,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25962615
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
25972616
Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
25982617
Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
2599-
Value result0 =
2618+
Value unnorm0 =
26002619
b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
2601-
Value result1 =
2620+
Value unnorm1 =
26022621
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
2622+
Value result0 =
2623+
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
2624+
Value result1 =
2625+
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
26032626
Value checkLowerBound0 = b.create<arith::CmpFOp>(
26042627
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
26052628
Value checkLowerBound1 = b.create<arith::CmpFOp>(

0 commit comments

Comments
 (0)