@@ -116,6 +116,83 @@ class ConvertAtenConstantPadNdOp
 namespace {
+class ConvertAtenReplicationPad1dOp
+    : public OpConversionPattern<AtenReplicationPad1dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+
+    if (inputRank < 2)
+      return rewriter.notifyMatchFailure(op, "input rank must be at least 2");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 2)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly two values");
+
+    int64_t leftPad = padInts[0];
+    int64_t rightPad = padInts[1];
+
+    int64_t dimToPad = inputRank - 1;
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+    Value widthSize = inputShape[dimToPad];
+    Value widthMinusOne = rewriter.create<arith::SubIOp>(loc, widthSize, one);
+
+    // Build offset and size arrays for slicing
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> leftOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> rightOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr(1)
+                                 : getAsOpFoldResult(inputShape[i]);
+
+    rightOffsets[dimToPad] = getAsOpFoldResult(widthMinusOne);
+
+    // Extract leftmost and rightmost slices
+    Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, leftOffsets, sizes, allOneStrides);
+    Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, rightOffsets, sizes, allOneStrides);
+
+    // Aggregate slices to concat together
+    SmallVector<Value> resultParts;
+    resultParts.reserve(leftPad + rightPad + 1);
+
+    resultParts.append(leftPad, leftSlice);
+    resultParts.push_back(input);
+    resultParts.append(rightPad, rightSlice);
+
+    Value result =
+        rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
 // Lower aten.replication_pad2d operator into a sequence of
 // tensor.extract_slice and tensor.concat operations.
@@ -621,6 +698,8 @@ void mlir::torch::torch_to_linalg::
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenReplicationPad2dOp>();
   patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
+  target.addIllegalOp<AtenReplicationPad1dOp>();
+  patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
   target.addIllegalOp<AtenConstantPadNdOp>();
   patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
   target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
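
For reference, replication ("edge") padding repeats the first and last entries along the padded dimension, which is what the extract_slice/concat sequence above builds: leftPad copies of the leftmost slice, the input itself, then rightPad copies of the rightmost slice. A minimal standalone C++ sketch of that semantics for a single row (illustrative only, not part of the patch; the helper name replicationPad1d is hypothetical and the row is assumed non-empty):

    #include <cstdint>
    #include <vector>

    // Pad one row by repeating its edge values: leftPad copies of the first
    // element, then the row itself, then rightPad copies of the last element.
    std::vector<float> replicationPad1d(const std::vector<float> &row,
                                        int64_t leftPad, int64_t rightPad) {
      std::vector<float> out;
      out.reserve(row.size() + leftPad + rightPad);
      out.insert(out.end(), leftPad, row.front());   // like resultParts.append(leftPad, leftSlice)
      out.insert(out.end(), row.begin(), row.end()); // like resultParts.push_back(input)
      out.insert(out.end(), rightPad, row.back());   // like resultParts.append(rightPad, rightSlice)
      return out;
    }

    // Example: replicationPad1d({1.0f, 2.0f, 3.0f}, 2, 1)
    // yields {1, 1, 1, 2, 3, 3}.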