@@ -107,6 +107,143 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
107107 return success ();
108108}
109109
110+ // Example:
111+ // input = tensor([[[0., 1., 2., 3.],
112+ // [4., 5., 6., 7.]]])
113+ // torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1
114+ // tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
115+ // [7., 6., 5., 4., 5., 6., 7., 6.]]])
116+ // Checks: 1) Each of padding_left and padding_right must be non-negative and less than the size of the last dimension
117+ // Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension.
118+ // The last dimension of the result tensor should be last dimension of input tensor +
119+ // left padding size + right padding size. Initialize result tensor to all zeros
120+ // b) Setup affine map to take slice from input tensor of size left padding starting from
121+ // second column onwards as first column is reflection boundary
122+ // c) Reflect the affine map to have resultant slice reflected
123+ // d) Take the slice and write from the beginning in result tensor
124+ // e) write the original tensor next into result tensor
125+ // f) Setup affine map to take slice from input tensor of right padding size ending
126+ // at second last column as last column is reflection boundary for right padding
127+ // g) Reflect the affine map to have resultant slice reflected
128+ // h) Take the slice and write from left padding size + original tensor last dim size
129+ // into result tensor
130+ // Uses the ideas/code used for AtenReflectionPad2dOp
131+ namespace {
132+ class ConvertAtenReflectionPad1dOp
133+ : public OpConversionPattern<AtenReflectionPad1dOp> {
134+ public:
135+ using OpConversionPattern::OpConversionPattern;
136+ LogicalResult
137+ matchAndRewrite (AtenReflectionPad1dOp op, OpAdaptor adaptor,
138+ ConversionPatternRewriter &rewriter) const override {
139+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
140+ return failure ();
141+
142+ SmallVector<int64_t > padInts;
143+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts)))
144+ return rewriter.notifyMatchFailure (
145+ op, " only constant int padding range is supported" );
146+
147+ MLIRContext *context = rewriter.getContext ();
148+ Location loc = op.getLoc ();
149+
150+ // Lambda Unitility Functions
151+ // Create an Integer expression of x + y
152+ auto createIAdd = [&](Value x, Value y) {
153+ return rewriter.create <arith::AddIOp>(loc, x, y);
154+ };
155+
156+ // Create an integer expression of x - y
157+ auto createISub = [&](Value x, Value y) {
158+ return rewriter.create <arith::SubIOp>(loc, x, y);
159+ };
160+
161+ enum PadLocation {PAD_LEFT = 0 , PAD_RIGHT = 1 , PAD_CENTER=2 };
162+
163+ Value input = adaptor.getSelf ();
164+ Type indexType = rewriter.getIndexType ();
165+ Value zero = getConstant (rewriter, loc, 0 , indexType);
166+ Value one = getConstant (rewriter, loc, 1 , indexType);
167+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
168+ auto outputType = llvm::cast<RankedTensorType>(getTypeConverter ()->convertType (op->getResult (0 ).getType ()));
169+ unsigned numDims = inputType.getRank ();
170+ assert (numDims >= 2 && " Not enough input dimensions" );
171+ int64_t lastDim = numDims - 1 ;
172+ SmallVector<Value> inputShape = getTensorSizes (rewriter, loc, input);
173+ Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4
174+
175+ Value tileWidth[3 ], extractOffset[3 ], insertOffset[3 ];
176+
177+ tileWidth[PAD_LEFT] = getConstant (rewriter, loc, padInts[PAD_LEFT], indexType);
178+ tileWidth[PAD_RIGHT] = getConstant (rewriter, loc, padInts[PAD_RIGHT], indexType);
179+ tileWidth[PAD_CENTER] = lastDimSize;
180+
181+ extractOffset[PAD_LEFT] = one;
182+ // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right
183+ // lasDimSize - (tileWidth[PAD_RIGHT] + one)
184+ extractOffset[PAD_RIGHT] = createISub (lastDimSize, createIAdd (tileWidth[PAD_RIGHT], one));
185+ extractOffset[PAD_CENTER] = zero;
186+
187+ insertOffset[PAD_LEFT] = zero;
188+ insertOffset[PAD_RIGHT] = createIAdd (lastDimSize, tileWidth[PAD_LEFT]);
189+ insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT];
190+
191+
192+ SmallVector<Value> resultShape{inputShape};
193+ // Result's last dimension will have shape lastDimSize + left padding size + right padding size
194+ resultShape[lastDim] = createIAdd (resultShape[lastDim], createIAdd (tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
195+ Value resultTensor = createZeroInitTensor (rewriter, loc, resultShape, inputType.getElementType ());
196+
197+ // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor
198+ // for which the corresponding dimension has a statically known size
199+ auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) {
200+ AffineExpr d = map.getResult (i);
201+ return map.replace (d, size - d - 1 , numDims, 0 ); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3
202+ };
203+
204+ SmallVector<utils::IteratorType> iteratorTypes{numDims, utils::IteratorType::parallel};
205+ auto idMap = AffineMap::getMultiDimIdentityMap (numDims, context);
206+ SmallVector<Value> allOneStrides (numDims, one);
207+
208+ auto addTileToResult = [&](PadLocation padPosition) {
209+ // Create the tile by extracting a slice from the input tensor.
210+ SmallVector<Value> extractShape{inputShape};
211+ extractShape[lastDim] = tileWidth[padPosition];
212+ SmallVector<Value> extractOffsets (numDims, zero);
213+ extractOffsets[lastDim] = extractOffset[padPosition];
214+ Value tile = rewriter.create <tensor::ExtractSliceOp>(
215+ loc, input, extractOffsets, extractShape, allOneStrides);
216+
217+
218+ auto inputMap = AffineMap::getMultiDimIdentityMap (numDims, context);
219+ // Setup the affine map function to resverse the tile along the horizontal for left and right slices
220+ if (padPosition < PAD_CENTER) {
221+ inputMap = reflectDim (inputMap, numDims, lastDim, padInts[padPosition]);
222+ // Take reflected slice as per inputMap
223+ tile = rewriter.create <linalg::GenericOp>(loc, llvm::cast<RankedTensorType>(tile.getType ()), tile,
224+ tile, ArrayRef ({inputMap, idMap}), iteratorTypes,
225+ [](OpBuilder &b, Location nestedLoc, ValueRange args) {
226+ b.create <linalg::YieldOp>(nestedLoc, args[0 ]);
227+ }).getResult (0 );
228+ }
229+ // Insert the tile in the resultTensor
230+ SmallVector<Value> insertOffsets (numDims, zero);
231+ insertOffsets[lastDim] = insertOffset[padPosition];
232+ resultTensor = rewriter.create <tensor::InsertSliceOp>(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
233+ };
234+
235+ if (padInts[PAD_LEFT] > 0 )
236+ addTileToResult (PAD_LEFT);
237+ if (padInts[PAD_RIGHT] > 0 )
238+ addTileToResult (PAD_RIGHT);
239+ addTileToResult (PAD_CENTER);
240+
241+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, outputType, resultTensor);
242+ return success ();
243+ }
244+ };
245+ }
246+
110247namespace {
111248class ConvertAtenFlattenUsingIntsOp
112249 : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1413,6 +1550,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
14131550 TypeConverter &typeConverter, RewritePatternSet &patterns,
14141551 ConversionTarget &target) {
14151552 MLIRContext *context = patterns.getContext ();
1553+ target.addIllegalOp <AtenReflectionPad1dOp>();
1554+ patterns.add <ConvertAtenReflectionPad1dOp>(typeConverter, context);
14161555 target.addIllegalOp <AtenFlattenUsingIntsOp>();
14171556 patterns.add <ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
14181557 target.addIllegalOp <AtenViewOp>();
0 commit comments