@@ -2593,22 +2593,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2593
2593
return b.create <arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2594
2594
};
2595
2595
2596
- auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2596
+ auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
2597
2597
Value x, Value SizeSubOne) -> Value {
2598
- Value border = lambdaBorder (b, loc, x, SizeSubOne);
2599
- Value zeroInt =
2600
- b.create <arith::ConstantOp>(loc, b.getIntegerAttr (int64type, 0 ));
2601
- Value isZero = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2602
- paddingMode, zeroInt);
2598
+ // Border
2599
+ if (paddingMode == 1 ) {
2600
+ return lambdaBorder (b, loc, x, SizeSubOne);
2601
+ }
2603
2602
2604
- return b. create <arith::SelectOp>(loc, isZero, x, border) ;
2603
+ return x ;
2605
2604
};
2606
2605
2607
2606
auto resultType = cast<RankedTensorType>(
2608
2607
getTypeConverter ()->convertType (op.getResult ().getType ()));
2609
2608
Value alignCorners = adaptor.getAlignCorners ();
2610
2609
Value interMode = adaptor.getInterpolationMode ();
2611
- Value paddingMode = adaptor.getPaddingMode ();
2610
+
2611
+ int64_t paddingModeInt;
2612
+ matchPattern (op.getPaddingMode (), m_TorchConstantInt (&paddingModeInt));
2612
2613
2613
2614
SmallVector<Value> dynamicSizes{};
2614
2615
if (resultType.isDynamicDim (0 ))
@@ -2642,9 +2643,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2642
2643
Value unnorm1 =
2643
2644
b.create <arith::AddFOp>(loc, gPlusMul1 , gr1HalfSelect);
2644
2645
Value result0 =
2645
- lambdaPadding (b, loc, paddingMode , unnorm0, innerDim0d);
2646
+ lambdaPadding (b, loc, paddingModeInt , unnorm0, innerDim0d);
2646
2647
Value result1 =
2647
- lambdaPadding (b, loc, paddingMode , unnorm1, innerDim1d);
2648
+ lambdaPadding (b, loc, paddingModeInt , unnorm1, innerDim1d);
2648
2649
Value checkLowerBound0 = b.create <arith::CmpFOp>(
2649
2650
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
2650
2651
Value checkLowerBound1 = b.create <arith::CmpFOp>(
0 commit comments