Skip to content

Commit bbf652c

Browse files
committed
[TorchToLinalg][GridSample] Add support for border padding mode
1 parent 95f7781 commit bbf652c

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,10 +2568,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25682568
return res;
25692569
};
25702570

2571+
auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2572+
Value SizeSubOne) -> Value {
2573+
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
2574+
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2575+
};
2576+
2577+
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2578+
Value x, Value SizeSubOne) -> Value {
2579+
Value border = lambdaBorder(b, loc, x, SizeSubOne);
2580+
Value zeroInt =
2581+
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
2582+
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2583+
paddingMode, zeroInt);
2584+
2585+
return b.create<arith::SelectOp>(loc, isZero, x, border);
2586+
};
2587+
25712588
auto resultType = cast<RankedTensorType>(
25722589
getTypeConverter()->convertType(op.getResult().getType()));
25732590
Value alignCorners = adaptor.getAlignCorners();
25742591
Value interMode = adaptor.getInterpolationMode();
2592+
Value paddingMode = adaptor.getPaddingMode();
2593+
25752594
SmallVector<Value> dynamicSizes{};
25762595
if (resultType.isDynamicDim(0))
25772596
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@@ -2599,10 +2618,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25992618
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
26002619
Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
26012620
Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
2602-
Value result0 =
2621+
Value unnorm0 =
26032622
b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
2604-
Value result1 =
2623+
Value unnorm1 =
26052624
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
2625+
Value result0 =
2626+
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
2627+
Value result1 =
2628+
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
26062629
Value checkLowerBound0 = b.create<arith::CmpFOp>(
26072630
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
26082631
Value checkLowerBound1 = b.create<arith::CmpFOp>(

0 commit comments

Comments
 (0)