Skip to content

Commit 2b0edc0

Browse files
committed
[TorchToLinalg][GridSample] Add support for border padding mode
1 parent 389541f commit 2b0edc0

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2587,10 +2587,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25872587
return res;
25882588
};
25892589

2590+
auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2591+
Value SizeSubOne) -> Value {
2592+
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
2593+
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2594+
};
2595+
2596+
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2597+
Value x, Value SizeSubOne) -> Value {
2598+
Value border = lambdaBorder(b, loc, x, SizeSubOne);
2599+
Value zeroInt =
2600+
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
2601+
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2602+
paddingMode, zeroInt);
2603+
2604+
return b.create<arith::SelectOp>(loc, isZero, x, border);
2605+
};
2606+
25902607
auto resultType = cast<RankedTensorType>(
25912608
getTypeConverter()->convertType(op.getResult().getType()));
25922609
Value alignCorners = adaptor.getAlignCorners();
25932610
Value interMode = adaptor.getInterpolationMode();
2611+
Value paddingMode = adaptor.getPaddingMode();
2612+
25942613
SmallVector<Value> dynamicSizes{};
25952614
if (resultType.isDynamicDim(0))
25962615
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@@ -2618,10 +2637,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
26182637
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
26192638
Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
26202639
Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
2621-
Value result0 =
2640+
Value unnorm0 =
26222641
b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
2623-
Value result1 =
2642+
Value unnorm1 =
26242643
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
2644+
Value result0 =
2645+
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
2646+
Value result1 =
2647+
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
26252648
Value checkLowerBound0 = b.create<arith::CmpFOp>(
26262649
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
26272650
Value checkLowerBound1 = b.create<arith::CmpFOp>(

0 commit comments

Comments
 (0)