@@ -2565,10 +2565,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2565
2565
return res;
2566
2566
};
2567
2567
2568
+ auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2569
+ Value SizeSubOne) -> Value {
2570
+ Value xMaxZero = b.create <arith::MaximumFOp>(loc, x, zeroFloat);
2571
+ return b.create <arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2572
+ };
2573
+
2574
+ auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2575
+ Value x, Value SizeSubOne) -> Value {
2576
+ Value border = lambdaBorder (b, loc, x, SizeSubOne);
2577
+ Value zeroInt =
2578
+ b.create <arith::ConstantOp>(loc, b.getIntegerAttr (int64type, 0 ));
2579
+ Value isZero = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2580
+ paddingMode, zeroInt);
2581
+
2582
+ return b.create <arith::SelectOp>(loc, isZero, x, border);
2583
+ };
2584
+
2568
2585
auto resultType = cast<RankedTensorType>(
2569
2586
getTypeConverter ()->convertType (op.getResult ().getType ()));
2570
2587
Value alignCorners = adaptor.getAlignCorners ();
2571
2588
Value interMode = adaptor.getInterpolationMode ();
2589
+ Value paddingMode = adaptor.getPaddingMode ();
2590
+
2572
2591
SmallVector<Value> dynamicSizes{};
2573
2592
if (resultType.isDynamicDim (0 ))
2574
2593
dynamicSizes.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
@@ -2596,10 +2615,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2596
2615
Value gplus1 = b.create <arith::AddFOp>(loc, gr1, oneFloat);
2597
2616
Value gPlusMul0 = b.create <arith::MulFOp>(loc, gplus0, innerDim0e);
2598
2617
Value gPlusMul1 = b.create <arith::MulFOp>(loc, gplus1, innerDim1e);
2599
- Value result0 =
2618
+ Value unnorm0 =
2600
2619
b.create <arith::AddFOp>(loc, gPlusMul0 , gr0HalfSelect);
2601
- Value result1 =
2620
+ Value unnorm1 =
2602
2621
b.create <arith::AddFOp>(loc, gPlusMul1 , gr1HalfSelect);
2622
+ Value result0 =
2623
+ lambdaPadding (b, loc, paddingMode, unnorm0, innerDim0d);
2624
+ Value result1 =
2625
+ lambdaPadding (b, loc, paddingMode, unnorm1, innerDim1d);
2603
2626
Value checkLowerBound0 = b.create <arith::CmpFOp>(
2604
2627
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
2605
2628
Value checkLowerBound1 = b.create <arith::CmpFOp>(
0 commit comments