@@ -2568,10 +2568,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2568
2568
return res;
2569
2569
};
2570
2570
2571
+ auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
2572
+ Value SizeSubOne) -> Value {
2573
+ Value xMaxZero = b.create <arith::MaximumFOp>(loc, x, zeroFloat);
2574
+ return b.create <arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2575
+ };
2576
+
2577
+ auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2578
+ Value x, Value SizeSubOne) -> Value {
2579
+ Value border = lambdaBorder (b, loc, x, SizeSubOne);
2580
+ Value zeroInt =
2581
+ b.create <arith::ConstantOp>(loc, b.getIntegerAttr (int64type, 0 ));
2582
+ Value isZero = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2583
+ paddingMode, zeroInt);
2584
+
2585
+ return b.create <arith::SelectOp>(loc, isZero, x, border);
2586
+ };
2587
+
2571
2588
auto resultType = cast<RankedTensorType>(
2572
2589
getTypeConverter ()->convertType (op.getResult ().getType ()));
2573
2590
Value alignCorners = adaptor.getAlignCorners ();
2574
2591
Value interMode = adaptor.getInterpolationMode ();
2592
+ Value paddingMode = adaptor.getPaddingMode ();
2593
+
2575
2594
SmallVector<Value> dynamicSizes{};
2576
2595
if (resultType.isDynamicDim (0 ))
2577
2596
dynamicSizes.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
@@ -2599,10 +2618,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2599
2618
Value gplus1 = b.create <arith::AddFOp>(loc, gr1, oneFloat);
2600
2619
Value gPlusMul0 = b.create <arith::MulFOp>(loc, gplus0, innerDim0e);
2601
2620
Value gPlusMul1 = b.create <arith::MulFOp>(loc, gplus1, innerDim1e);
2602
- Value result0 =
2621
+ Value unnorm0 =
2603
2622
b.create <arith::AddFOp>(loc, gPlusMul0 , gr0HalfSelect);
2604
- Value result1 =
2623
+ Value unnorm1 =
2605
2624
b.create <arith::AddFOp>(loc, gPlusMul1 , gr1HalfSelect);
2625
+ Value result0 =
2626
+ lambdaPadding (b, loc, paddingMode, unnorm0, innerDim0d);
2627
+ Value result1 =
2628
+ lambdaPadding (b, loc, paddingMode, unnorm1, innerDim1d);
2606
2629
Value checkLowerBound0 = b.create <arith::CmpFOp>(
2607
2630
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
2608
2631
Value checkLowerBound1 = b.create <arith::CmpFOp>(
0 commit comments