Commit 9adad9b

Add support for reflection_pad1d (#2706)
Adds a lowering to Linalg for reflection_pad1d. Based on ideas/code from draft PR #2693. --------- Co-authored-by: Kumar Deepak <[email protected]>
1 parent 6660a26 commit 9adad9b
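
For reference, this is what the new op computes in eager PyTorch. The snippet below is a minimal illustration matching the example given in the lowering's comment block, not code from this patch:

import torch

# aten::reflection_pad1d takes (left, right) padding for the last dimension
# and reflects around the edge elements without repeating them.
x = torch.tensor([[[0., 1., 2., 3.],
                   [4., 5., 6., 7.]]])
y = torch.ops.aten.reflection_pad1d(x, (3, 1))
print(y)
# tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
#          [7., 6., 5., 4., 5., 6., 7., 6.]]])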

File tree (6 files changed, +313 -0 lines changed):

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
lib/Conversion/TorchToLinalg/DataMovement.cpp
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -7869,6 +7869,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
 }];
 }
 
+def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReflectionPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReflectionPad1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenPadOp : Torch_Op<"aten.pad", [
   AllowsTypeRefinement,
   HasValueSemantics,

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 139 additions & 0 deletions
@@ -107,6 +107,143 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   return success();
 }
 
+// Example:
+//   input = tensor([[[0., 1., 2., 3.],
+//                    [4., 5., 6., 7.]]])
+//   torch.ops.aten.reflection_pad1d(input, (3, 1)) ; padding_left = 3, padding_right = 1
+//   tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
+//            [7., 6., 5., 4., 5., 6., 7., 6.]]])
+//
+// Checks: 1) Each of padding_left and padding_right must be non-negative and
+//            less than the size of the last dimension.
+// Implementation:
+//   a) Construct a result tensor with the shape of the input tensor, except that
+//      the last dimension of the result is the last dimension of the input +
+//      left padding size + right padding size. Initialize the result tensor to all zeros.
+//   b) Set up an affine map to take a slice of left-padding width from the input
+//      tensor, starting from the second column, since the first column is the
+//      reflection boundary.
+//   c) Reflect the affine map so that the resulting slice is reversed.
+//   d) Take the slice and write it at the beginning of the result tensor.
+//   e) Write the original tensor next into the result tensor.
+//   f) Set up an affine map to take a slice of right-padding width from the input
+//      tensor, ending at the second-to-last column, since the last column is the
+//      reflection boundary for the right padding.
+//   g) Reflect the affine map so that the resulting slice is reversed.
+//   h) Take the slice and write it into the result tensor at offset
+//      left padding size + original tensor last dim size.
+// Uses the ideas/code from AtenReflectionPad2dOp.
+namespace {
+class ConvertAtenReflectionPad1dOp
+    : public OpConversionPattern<AtenReflectionPad1dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only constant int padding range is supported");
+
+    MLIRContext *context = rewriter.getContext();
+    Location loc = op.getLoc();
+
+    // Lambda utility functions.
+    // Create an integer expression of x + y.
+    auto createIAdd = [&](Value x, Value y) {
+      return rewriter.create<arith::AddIOp>(loc, x, y);
+    };
+
+    // Create an integer expression of x - y.
+    auto createISub = [&](Value x, Value y) {
+      return rewriter.create<arith::SubIOp>(loc, x, y);
+    };
+
+    enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 };
+
+    Value input = adaptor.getSelf();
+    Type indexType = rewriter.getIndexType();
+    Value zero = getConstant(rewriter, loc, 0, indexType);
+    Value one = getConstant(rewriter, loc, 1, indexType);
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    auto outputType = llvm::cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
+    unsigned numDims = inputType.getRank();
+    assert(numDims >= 2 && "Not enough input dimensions");
+    int64_t lastDim = numDims - 1;
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+    // E.g. for a (1, 2, 4) input, lastDim = 2 and inputShape[2] gives 4.
+    Value lastDimSize = inputShape[lastDim];
+
+    Value tileWidth[3], extractOffset[3], insertOffset[3];
+
+    tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
+    tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
+    tileWidth[PAD_CENTER] = lastDimSize;
+
+    extractOffset[PAD_LEFT] = one;
+    // The right tile starts at lastDimSize - (tileWidth[PAD_RIGHT] + one).
+    // E.g. for a (1, 2, 4) input with padding (3, 1): 4 - 1 - 1 = 2, so the
+    // right tile is the single column at index 2 (values 2. and 6.), which is correct.
+    extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
+    extractOffset[PAD_CENTER] = zero;
+
+    insertOffset[PAD_LEFT] = zero;
+    insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]);
+    insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT];
+
+    SmallVector<Value> resultShape{inputShape};
+    // The result's last dimension has size lastDimSize + left padding size + right padding size.
+    resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
+    Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType());
+
+    // Helper to reflect/reverse the i-th dimension of an affine map without symbols.
+    // This only works if applied on a tensor for which the corresponding dimension
+    // has a statically known size.
+    auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) {
+      AffineExpr d = map.getResult(i);
+      // E.g. left reflection for padding (3, 1) on an input of shape (1, 2, 4):
+      // size = 3, lastDim = 2, numDims = 3.
+      return map.replace(d, size - d - 1, numDims, 0);
+    };
+
+    SmallVector<utils::IteratorType> iteratorTypes{numDims, utils::IteratorType::parallel};
+    auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+    SmallVector<Value> allOneStrides(numDims, one);
+
+    auto addTileToResult = [&](PadLocation padPosition) {
+      // Create the tile by extracting a slice from the input tensor.
+      SmallVector<Value> extractShape{inputShape};
+      extractShape[lastDim] = tileWidth[padPosition];
+      SmallVector<Value> extractOffsets(numDims, zero);
+      extractOffsets[lastDim] = extractOffset[padPosition];
+      Value tile = rewriter.create<tensor::ExtractSliceOp>(
+          loc, input, extractOffsets, extractShape, allOneStrides);
+
+      auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+      // Set up the affine map to reverse the tile along the horizontal for the
+      // left and right slices.
+      if (padPosition < PAD_CENTER) {
+        inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]);
+        // Take the reflected slice as per inputMap.
+        tile = rewriter.create<linalg::GenericOp>(loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
+            tile, ArrayRef({inputMap, idMap}), iteratorTypes,
+            [](OpBuilder &b, Location nestedLoc, ValueRange args) {
+              b.create<linalg::YieldOp>(nestedLoc, args[0]);
+            }).getResult(0);
+      }
+      // Insert the tile into resultTensor.
+      SmallVector<Value> insertOffsets(numDims, zero);
+      insertOffsets[lastDim] = insertOffset[padPosition];
+      resultTensor = rewriter.create<tensor::InsertSliceOp>(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
+    };
+
+    if (padInts[PAD_LEFT] > 0)
+      addTileToResult(PAD_LEFT);
+    if (padInts[PAD_RIGHT] > 0)
+      addTileToResult(PAD_RIGHT);
+    addTileToResult(PAD_CENTER);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenFlattenUsingIntsOp
     : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1413,6 +1550,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenReflectionPad1dOp>();
+  patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
   target.addIllegalOp<AtenFlattenUsingIntsOp>();
   patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
   target.addIllegalOp<AtenViewOp>();
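
The extract/reflect/insert steps (a)-(h) described in the comment block of ConvertAtenReflectionPad1dOp can be summarized in plain PyTorch. This is a sketch of the lowering's tiling strategy under the same assumptions (constant padding, padding smaller than the last dimension); the helper name is invented for illustration and is not part of this patch:

import torch

def reflection_pad1d_tiles(x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
    # Mirror of the lowering: zero-initialize the result, then insert a reversed
    # left tile, a reversed right tile, and the original tensor along the last dim.
    w = x.shape[-1]
    out = x.new_zeros(*x.shape[:-1], w + pad_left + pad_right)
    if pad_left > 0:
        # Columns 1..pad_left (column 0 is the reflection boundary), reversed,
        # written at offset 0.
        out[..., :pad_left] = x[..., 1:pad_left + 1].flip(-1)
    if pad_right > 0:
        # pad_right columns ending at the second-to-last column (the last column
        # is the boundary), reversed, written at offset pad_left + w.
        out[..., pad_left + w:] = x[..., w - pad_right - 1:w - 1].flip(-1)
    # The original tensor goes in the middle.
    out[..., pad_left:pad_left + w] = x
    return out

x = torch.arange(8, dtype=torch.float32).reshape(1, 2, 4)
assert torch.equal(reflection_pad1d_tiles(x, 3, 1),
                   torch.ops.aten.reflection_pad1d(x, (3, 1)))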

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 50 additions & 0 deletions
@@ -8331,6 +8331,41 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %false = torch.constant.bool false\n"
+" %int-1 = torch.constant.int -1\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %int2 = torch.constant.int 2\n"
+" %int1 = torch.constant.int 1\n"
+" %int0 = torch.constant.int 0\n"
+" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %1 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n"
+" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %8 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If %6 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+" return %7 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -8952,6 +8987,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
+" %int2 = torch.constant.int 2\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+" %2 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %2 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 27 additions & 0 deletions
@@ -1271,6 +1271,21 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float
 def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
 
+# Padding size must be smaller than the size of the last dimension
+@check_shape_function([ErrorInvocation(TensorOfShape(1, 2, 4), padding=[4,1]),
+                       Invocation(TensorOfShape(1, 2, 4), padding=[3,3]),
+                       ErrorInvocation(TensorOfShape(1, 2, 4), padding=[1,4]),
+                       ErrorInvocation(TensorOfShape(1, 4), padding=[4,1]),
+                       Invocation(TensorOfShape(1, 4), padding=[3,3]),
+                       ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])])
+def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]:
+    assert len(self) >= 2
+    hdim = self[-1]
+    padding_left = padding[0]
+    padding_right = padding[1]
+    assert padding_left < hdim and padding_right < hdim
+    return pad_shape_fn(self, padding)
+
 # TODO: upstream this
 def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
     assert len(indices) <= len(self), "More indices than dimensions to index"
@@ -1804,6 +1819,18 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+
+@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
+                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
+                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
+                       Invocation(TensorOfShape(2, 3, 4), padding=[2,1]),
+                       Invocation(TensorOfShape(5, 5, 4), padding=[1,2]),
+                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1])])
+def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert len(padding) == 2, 'padding size expected to be 2'
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int:
     self_rank, self_dtype = self_rank_dtype
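
For illustration only, the rule these shape and dtype functions encode is that only the last dimension grows, by padding[0] + padding[1]; padding must have exactly two entries, each smaller than the last input dimension. The helper below is a hypothetical sketch of that rule, not the library's pad_shape_fn:

from typing import List

def reflection_pad1d_out_shape(self: List[int], padding: List[int]) -> List[int]:
    # Hypothetical sketch of the rule encoded above.
    assert len(padding) == 2, 'padding size expected to be 2'
    assert len(self) >= 2
    assert padding[0] < self[-1] and padding[1] < self[-1]
    return self[:-1] + [self[-1] + padding[0] + padding[1]]

print(reflection_pad1d_out_shape([1, 2, 4], [3, 1]))  # [1, 2, 8]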

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -541,6 +541,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
+    emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
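
As a rough illustration, the registry string added above describes an op with two operands (Tensor, int[]) and one result, which matches the parseDefaultTorchOp(parser, result, 2, 1) calls in the generated op definition earlier in this diff. The helper below is hypothetical and not part of torch_ods_gen.py:

def count_operands_and_results(sig: str):
    # Split "name : (args) -> (results)" and count the comma-separated entries.
    _, rest = sig.split(" : ", 1)
    args, rets = rest.split(" -> ")
    n_args = len([a for a in args.strip("()").split(",") if a.strip()])
    n_rets = len([r for r in rets.strip("()").split(",") if r.strip()])
    return n_args, n_rets

print(count_operands_and_results("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)"))  # (2, 1)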

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 72 additions & 0 deletions
@@ -552,8 +552,80 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
 
 
 # ==============================================================================
+class ReflectionPad1dModule3dInput(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 2, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad1d(x, (3,1))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput())
+def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1,2,4))
+
+
+class ReflectionPad1dModule2dInput(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad1d(x, (3,2))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput())
+def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2,4))
+
+
+class ReflectionPad1dModule3dInputLeft(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 4, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad1d(x, (2,0))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft())
+def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils):
+    module.forward(tu.rand(1,4,5))
+
+
+class ReflectionPad1dModule2dInputRight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 6], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad1d(x, (0,3))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight())
+def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils):
+    module.forward(tu.rand(3,6))
+
+
+# ==============================================================================
 class TransposeIntModule(torch.nn.Module):
 
     def __init__(self):
