@@ -2587,10 +2587,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2587
2587
return res;
2588
2588
};
2589
2589
2590
+ auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2591
+ Value SizeSubOne) -> Value {
2592
+ Value xMaxZero = b.create <arith::MaximumFOp>(loc, x, zeroFloat);
2593
+ return b.create <arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2594
+ };
2595
+
2596
+ auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2597
+ Value x, Value SizeSubOne) -> Value {
2598
+ Value border = lambdaBorder (b, loc, x, SizeSubOne);
2599
+ Value zeroInt =
2600
+ b.create <arith::ConstantOp>(loc, b.getIntegerAttr (int64type, 0 ));
2601
+ Value isZero = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2602
+ paddingMode, zeroInt);
2603
+
2604
+ return b.create <arith::SelectOp>(loc, isZero, x, border);
2605
+ };
2606
+
2590
2607
auto resultType = cast<RankedTensorType>(
2591
2608
getTypeConverter ()->convertType (op.getResult ().getType ()));
2592
2609
Value alignCorners = adaptor.getAlignCorners ();
2593
2610
Value interMode = adaptor.getInterpolationMode ();
2611
+ Value paddingMode = adaptor.getPaddingMode ();
2612
+
2594
2613
SmallVector<Value> dynamicSizes{};
2595
2614
if (resultType.isDynamicDim (0 ))
2596
2615
dynamicSizes.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
@@ -2618,10 +2637,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2618
2637
Value gplus1 = b.create <arith::AddFOp>(loc, gr1, oneFloat);
2619
2638
Value gPlusMul0 = b.create <arith::MulFOp>(loc, gplus0, innerDim0e);
2620
2639
Value gPlusMul1 = b.create <arith::MulFOp>(loc, gplus1, innerDim1e);
2621
- Value result0 =
2640
+ Value unnorm0 =
2622
2641
b.create <arith::AddFOp>(loc, gPlusMul0 , gr0HalfSelect);
2623
- Value result1 =
2642
+ Value unnorm1 =
2624
2643
b.create <arith::AddFOp>(loc, gPlusMul1 , gr1HalfSelect);
2644
+ Value result0 =
2645
+ lambdaPadding (b, loc, paddingMode, unnorm0, innerDim0d);
2646
+ Value result1 =
2647
+ lambdaPadding (b, loc, paddingMode, unnorm1, innerDim1d);
2625
2648
Value checkLowerBound0 = b.create <arith::CmpFOp>(
2626
2649
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
2627
2650
Value checkLowerBound1 = b.create <arith::CmpFOp>(
0 commit comments