
Commit 8130354

vivekkhandelwal1 authored and Prashant Kumar committed
[MLIR][TORCH] Add E2E support for aten.index_select op
This commit adds lowering of the `aten.index_select` op.

Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent 0a0a1b4 commit 8130354
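
For reference, `torch.index_select` gathers slices of the input along a given dimension at the positions named by a 1-D index tensor. A minimal sketch of the op being lowered, using shapes from the tests added below (illustrative only, not part of the commit):

import torch

input = torch.randn(4, 5, 6)
indices = torch.tensor([0, 2])
# Keep only positions 0 and 2 along dim 1; the result has shape [4, 2, 6].
output = torch.index_select(input, 1, indices)
assert output.shape == (4, 2, 6)
assert torch.equal(output[:, 1, :], input[:, 2, :])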

File tree

3 files changed: +224 −0 lines changed

e2e_testing/torchscript/index_select.py
e2e_testing/torchscript/main.py
lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
e2e_testing/torchscript/index_select.py

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class IndexSelectSingleIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2]))


class IndexSelectTwoIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([2], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))


class IndexSelectWholeDimensionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([4], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))


class IndexSelectWholeTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3], torch.float32, True),
        ([3], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3), torch.tensor([0, 1, 2]))


class IndexSelectDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))


class IndexSelectDynamicInputSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([2], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))


class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
 from . import squeeze
 from . import slice_like
 from . import nll_loss
+from . import index_select

 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 78 additions & 0 deletions
@@ -3439,6 +3439,82 @@ class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
 };
 } // namespace

+namespace {
+// Suppose we have an input tensor of size [4, 5, 6] initialized with some
+// random values, an index tensor (always 1-D) [0, 2] of size [2], and the
+// integer argument dim = 1. The output tensor then has size [4, 2, 6].
+// The approach is as follows:
+//
+// for i in range(input.size[0])
+//   for j in range(index.size[0])
+//     for k in range(input.size[2])
+//       indexValue = index[j]
+//       output[i,j,k] = input[i,indexValue,k]
+
+class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.self();
+    Value indices = adaptor.index();
+    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    Type elementType = resultType.getElementType();
+    unsigned inputRank = inputType.getRank();
+
+    int64_t dimInt;
+    if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
+      return op->emitError("unimplemented: dim is not constant");
+
+    SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
+    resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
+    Value initTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
+
+    SmallVector<AffineExpr> resultExpr;
+    AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
+    SmallVector<StringRef> iteratorTypes;
+
+    for (unsigned i = 0; i < inputRank; i++) {
+      resultExpr.push_back(rewriter.getAffineDimExpr(i));
+      iteratorTypes.push_back(getParallelIteratorTypeName());
+    }
+
+    auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initTensor.getType(), ValueRange{indices}, initTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  SmallVector<Value> indexTarget;
+                  for (unsigned i = 0; i < inputRank; i++)
+                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
+                  indexTarget[dimInt] = index;
+                  Value extractedElement =
+                      b.create<tensor::ExtractOp>(loc, input, indexTarget);
+                  b.create<linalg::YieldOp>(loc, extractedElement);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -3539,6 +3615,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
     target.addIllegalOp<AtenNllLossForwardOp>();
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
+    target.addIllegalOp<AtenIndexSelectOp>();
+    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))