
Commit 4486de5

[MLIR][TORCH] Add E2E support for torch.arange op
This commit adds lowering of the `aten.arange.start_step` op and decomposes `aten.arange` and `aten.arange.start` into `aten.arange.start_step`.

Signed-Off By: Vivek Khandelwal <[email protected]>
Parent: a83004c
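As a rough illustration (not part of the commit), the decomposition described above rewrites the one- and two-argument `arange` forms into the three-argument start/step form. A minimal Python sketch using the public `torch.arange` overloads; the helper name `arange_via_start_step` is purely illustrative:

import torch

# Illustrative only: mirrors the decomposition in the commit message using the
# public torch.arange overloads rather than the internal ATen ops.
#   aten.arange(end)              ~> aten.arange.start_step(0, end, 1)
#   aten.arange.start(start, end) ~> aten.arange.start_step(start, end, 1)
def arange_via_start_step(start, end=None, step=1):
    if end is None:          # single-argument form: arange(end)
        start, end = 0, start
    return torch.arange(start, end, step)

assert torch.equal(arange_via_start_step(5), torch.arange(5))
assert torch.equal(arange_via_start_step(0, 5), torch.arange(0, 5))
assert torch.equal(arange_via_start_step(10, 1, -2), torch.arange(10, 1, -2))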

8 files changed: 465 additions, 6 deletions

e2e_testing/torchscript/arange.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class ArangeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5)

@register_test_case(module_factory=lambda: ArangeIntModule())
def ArangeIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5.0)

@register_test_case(module_factory=lambda: ArangeFloatModule())
def ArangeFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeZeroElementOutputModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0)

@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0, 5)

@register_test_case(module_factory=lambda: ArangeStartIntModule())
def ArangeStartIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0.0, 5.0)

@register_test_case(module_factory=lambda: ArangeStartFloatModule())
def ArangeStartFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeNegativeStartIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-10, 5)

@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeNegativeStartFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1.4, 5.7)

@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartStepIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0, 5, 1)

@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartStepFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, 5, 1.3)

@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartNegativeStepIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(10, 1, -2)

@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeStartNegativeStepFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, -15, -3.4)

@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeDtypeFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(-1, 15, dtype=torch.float32)

@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeDtypeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(0.2, 5.0, dtype=torch.int64)

@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
    module.forward()


class ArangeFalsePinMemoryModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.arange(5.0, dtype=torch.int64, pin_memory=False)

@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
    module.forward()

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
 from . import slice_like
 from . import nll_loss
 from . import index_select
+from . import arange
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 20 additions & 0 deletions
@@ -2028,6 +2028,26 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
   let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
 }
 
+def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchScalarType:$start,
+    AnyTorchScalarType:$end,
+    AnyTorchScalarType:$step,
+    TorchOptionalIntType:$dtype,
+    TorchOptionalIntType:$layout,
+    TorchOptionalDeviceType:$device,
+    TorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$start `,` $end `,` $step `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($step) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
+}
+
 def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
     AllowsTypeRefinement,
     HasValueSemantics
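For context (not part of the diff), the op above mirrors the registered ATen schema `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`. A small hedged Python check, assuming the overload is reachable through `torch.ops` as it is in recent PyTorch releases:

import torch

# Hedged sketch: call the aten::arange.start_step overload directly and compare
# it with the ordinary torch.arange result. The torch.ops access path is an
# assumption about how the registered overload is exposed, not something this
# commit adds.
a = torch.ops.aten.arange.start_step(0, 5, 1)
b = torch.arange(0, 5, 1)
assert torch.equal(a, b)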

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 95 additions & 0 deletions
@@ -4022,6 +4022,99 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 };
 } // namespace
 
+namespace {
+// Let's say the result of the `aten.arange.start_step` is `output` which is a
+// 1-d output tensor. The approach used for generating the output tensor is as
+// follows:
+//    for i in range(ceil((end-start)/step))
+//        output[i] = start + (i * step)
+class ConvertAtenArangeStartStepOp
+    : public OpConversionPattern<AtenArangeStartStepOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // TODO: Add support for layout, pin_memory features.
+    // Only `none` layout is supported.
+    if (!op.layout().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default layout is supported");
+
+    // The pin_memory should be either `False` or `none`.
+    bool pinMemory;
+    if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
+    }
+
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    RankedTensorType resultType =
+        typeConverter->convertType(op->getResult(0).getType())
+            .cast<RankedTensorType>();
+    Type dtype = resultType.getElementType();
+    Value start = convertScalarToDtype(rewriter, loc, adaptor.start(), dtype);
+    Value end = convertScalarToDtype(rewriter, loc, adaptor.end(), dtype);
+    Value step = convertScalarToDtype(rewriter, loc, adaptor.step(), dtype);
+
+    // The result will always be a 1-d tensor.
+    // The size of the result is calculated as follows:
+    //          ceil((end - start)/step)
+    Value resultShape;
+    if (dtype.isa<mlir::IntegerType>()) {
+      Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
+      resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
+    } else {
+      Value subOut = rewriter.create<arith::SubFOp>(loc, end, start);
+      Value divOut = rewriter.create<arith::DivFOp>(loc, subOut, step);
+      Value ceilOut = rewriter.create<math::CeilOp>(loc, divOut);
+      resultShape =
+          rewriter.create<arith::FPToUIOp>(loc, rewriter.getI64Type(), ceilOut);
+    }
+    resultShape = castIntToIndex(rewriter, loc, resultShape);
+
+    Value resultTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, resultShape, dtype);
+
+    StringRef iteratorType = getParallelIteratorTypeName();
+    AffineMap indexingMap =
+        AffineMap::getMultiDimIdentityMap(1, op->getContext());
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, /*resultTensorTypes=*/resultTensor.getType(),
+                /*inputs=*/ValueRange({}),
+                /*outputs=*/resultTensor,
+                /*indexingMaps=*/indexingMap,
+                /*iteratorTypes=*/iteratorType,
+                [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
+                  Value index = b.create<linalg::IndexOp>(loc, 0);
+                  index = castIndexToInt(b, loc, index);
+                  index = convertScalarToDtype(b, loc, index, dtype);
+                  Value mulOut, result;
+                  if (dtype.isa<mlir::FloatType>()) {
+                    mulOut = b.create<arith::MulFOp>(loc, step, index);
+                    result = b.create<arith::AddFOp>(loc, start, mulOut);
+                  } else {
+                    mulOut = b.create<arith::MulIOp>(loc, step, index);
+                    result = b.create<arith::AddIOp>(loc, start, mulOut);
+                  }
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -4134,6 +4227,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
     patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
     target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
+    patterns.add<ConvertAtenArangeStartStepOp>(typeConverter, context);
+    target.addIllegalOp<AtenArangeStartStepOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
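For intuition, here is a small standalone Python check (not part of the commit) of the size formula and per-element computation used by the lowering above: the result has ceil((end - start) / step) elements and output[i] = start + i * step. `math.ceil` stands in for the arith.CeilDivSIOp / math.CeilOp path, and the helper name `reference_arange` is purely illustrative:

import math
import torch

def reference_arange(start, end, step):
    # Size of the 1-d result: ceil((end - start) / step), as in the lowering above.
    n = max(0, math.ceil((end - start) / step))
    # Each element: output[i] = start + i * step.
    return [start + i * step for i in range(n)]

# Compare against torch.arange for a few of the cases exercised by the e2e tests.
for start, end, step in [(0, 5, 1), (10, 1, -2), (-1, 5, 1.3), (-1, -15, -3.4)]:
    expected = torch.arange(start, end, step)
    got = reference_arange(start, end, step)
    assert len(got) == expected.numel()
    assert all(abs(g - e) < 1e-5 for g, e in zip(got, expected.tolist()))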
