
Commit 6fe70c7

[MLIR][TORCH] Add E2E support for aten.index.Tensor op
This commit adds lowering of the `aten.index.Tensor` op.

Signed-off-by: Vivek Khandelwal <[email protected]>
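For context, `aten.index.Tensor` performs advanced indexing: every element of an index tensor selects an element of the input, and the result takes the index tensor's shape. A minimal PyTorch illustration of the single-index, 1-d-input case this commit supports (the values are made up for illustration):

import torch

x = torch.tensor([10., 20., 30., 40., 50.])
idx = torch.tensor([[0, 2, 4], [1, 1, 3]])
# aten.index gathers x[idx[i][j]] for every (i, j), so the result
# has the index tensor's shape, here (2, 3).
out = torch.ops.aten.index(x, (idx,))
assert out.shape == (2, 3)
assert torch.equal(out, x[idx])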
1 parent 0188ca5 commit 6fe70c7

5 files changed: +170 −24 lines


e2e_testing/torchscript/basic.py

Lines changed: 18 additions & 0 deletions
@@ -1069,3 +1069,21 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: ReturnTwoTensorF32I64())
 def ReturnTwoTensorF32I64_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), torch.randint(5, (2, 3)))
+
+
+class IndexTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, (index,))
+
+@register_test_case(module_factory=lambda: IndexTensorModule())
+def IndexTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), torch.randint(4, (2, 3)))

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 79 additions & 0 deletions
@@ -4231,6 +4231,83 @@ class ConvertAtenArangeStartStepOp
 };
 } // namespace
 
+namespace {
+class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.self();
+    Value indices = op.indices();
+    SmallVector<Value> indicesTuple;
+    if (!getListConstructElements(indices, indicesTuple)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the indices list is not from a list construct");
+    }
+
+    SmallVector<Value> indicesVal =
+        getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
+
+    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    Type elementType = resultType.getElementType();
+    unsigned inputRank = inputType.getRank();
+    unsigned numIndexTensors = indicesTuple.size();
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+
+    // Case 1 : When numIndexTensors == 1 and `input` is a 1-d tensor.
+    // TODO: generalize the implementation for other cases.
+    if (numIndexTensors == 1 && inputRank == 1) {
+      if (failed(checkNotNone(rewriter, op, indicesVal[0])))
+        return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+      unsigned resultRank =
+          indicesVal[0].getType().cast<RankedTensorType>().getRank();
+      SmallVector<Value> resultShape;
+      SmallVector<AffineExpr> indicesExpr, resultExpr;
+      SmallVector<StringRef> iteratorTypes;
+      for (unsigned i = 0; i < resultRank; i++)
+        resultShape.push_back(getDimOp(rewriter, loc, indicesVal[0], i));
+      Value initTensor =
+          rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
+      for (unsigned i = 0; i < resultRank; i++) {
+        indicesExpr.push_back(rewriter.getAffineDimExpr(i));
+        resultExpr.push_back(rewriter.getAffineDimExpr(i));
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+      }
+      auto indexingMaps =
+          AffineMap::inferFromExprList({indicesExpr, resultExpr});
+
+      Value finalRes =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, initTensor.getType(), ValueRange{indicesVal[0]},
+                  initTensor,
+                  /*indexingMaps=*/indexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [&](OpBuilder &b, Location loc, ValueRange args) {
+                    Value indexTarget = castIntToIndex(b, loc, args[0]);
+                    Value extractedElement =
+                        b.create<tensor::ExtractOp>(loc, input, indexTarget);
+                    b.create<linalg::YieldOp>(loc, extractedElement);
+                  })
+              .getResult(0);
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+      return success();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: support for this set of inputs not present");
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -4348,6 +4425,8 @@ class ConvertTorchToLinalg
     target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
     patterns.add<ConvertAtenArangeStartStepOp>(typeConverter, context);
     target.addIllegalOp<AtenArangeStartStepOp>();
+    patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
+    target.addIllegalOp<AtenIndexTensorOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
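The `linalg.generic` built above amounts to an elementwise gather: one parallel loop per result dimension, whose body does a `tensor.extract` from the input at the int-to-index-cast index value. A scalar Python model of that computation, under the same 1-d-input / single-index-tensor restriction (the helper name is illustrative, not part of the patch):

import torch

def index_1d_reference(input, indices):
    # One "parallel iteration" per element of `indices`; the body mirrors
    # castIntToIndex + tensor.extract + linalg.yield in the pattern above.
    flat = indices.reshape(-1)
    out = torch.empty(flat.numel(), dtype=input.dtype)
    for i in range(flat.numel()):
        out[i] = input[int(flat[i])]
    # The init tensor is sized from the index tensor, so the result
    # takes the index tensor's shape.
    return out.reshape(indices.shape)

For in-bounds indices this agrees with `torch.ops.aten.index(input, (indices,))`.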

lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp

Lines changed: 20 additions & 7 deletions
@@ -39,22 +39,34 @@ class ConvertToImmutableTensors : public RewritePattern {
         opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
                                                            opOperand.get()));
       } else if (auto listType = operandType.dyn_cast<ListType>()) {
-        if (!listType.getContainedType().isa<NonValueTensorType>())
+        if (!(listType.getContainedType().isa<NonValueTensorType>() ||
+              listType.getContainedType().isa<OptionalType>()))
           continue;
 
         // Construct a new list whose elements are value tensors copied from
-        // the none value tensors of the original list.
+        // the non-value tensors of the original list.
         auto listConstruct =
             opOperand.get().getDefiningOp<PrimListConstructOp>();
         if (!listConstruct) {
           rewriter.cancelRootUpdate(op);
-          return rewriter.notifyMatchFailure(op,
-              "unimplemented: list of non vtensor type not constructed "
-              "from list construct");
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: list of non vtensor type not constructed "
+                  "from list construct");
         }
 
         if (listConstruct.elements().empty())
           continue;
+
+        // TODO: Handle optional type in list type.
+        if (listType.getContainedType().isa<OptionalType>()) {
+          if (!llvm::all_of(listConstruct.elements(), [](Value val) {
+                return val.getType().isa<NonValueTensorType>();
+              }))
+            return rewriter.notifyMatchFailure(
+                op, "unimplemented: list containing optional type is not "
+                    "handled.");
+        }
+
         auto newListElements = llvm::to_vector<4>(llvm::map_range(
             listConstruct.elements(), [&](Value tensor) -> Value {
               return rewriter.create<CopyToValueTensorOp>(op->getLoc(), tensor);
@@ -74,8 +86,9 @@ class ConvertToImmutableTensors : public RewritePattern {
         auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
         if (!derefine) {
           rewriter.cancelRootUpdate(op);
-          return rewriter.notifyMatchFailure(op,
-              "unimplemented: optional of non vtensor type not from derefine");
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: optional of non vtensor type not from "
+                  "derefine");
         }
 
         if (!derefine.operand().getType().isa<NonValueTensorType>())
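A note on why `OptionalType` shows up in the list at all: in PyTorch the indices argument of `aten.index.Tensor` is typed `Tensor?[]`, and a `None` entry means the corresponding dimension is not indexed. The pattern above therefore accepts lists of optional type, but bails out unless every element is actually a tensor. A small PyTorch example of the `None` convention (assuming the usual `torch.ops.aten.index` binding):

import torch

x = torch.rand(4, 5)
idx = torch.tensor([0, 2])
# None leaves dim 0 alone; only dim 1 is indexed, so the result
# shape is (4, 2), the same as x[:, idx].
out = torch.ops.aten.index(x, (None, idx))
assert out.shape == (4, 2)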

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 35 additions & 17 deletions
@@ -282,23 +282,6 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       knowledge.dtype = input.dtype;
       return getLatticeElement(op->getResult(0)).join(knowledge);
     }
-    // `torch.aten.index.Tensor` return tensors elements selected by the index
-    // tensors. Each index tensor in the list corresponds to each dim in the
-    // input tensor.
-    // Same number of dims but unknown size along each dim. Same dtype as the
-    // input.
-    if (auto indexTensor = dyn_cast<AtenIndexTensorOp>(op)) {
-      auto input = operands[0]->getValue();
-      auto knowledge =
-          ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-      if (input.hasSizes) {
-        knowledge.hasSizes = true;
-        knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
-      }
-      knowledge.dtype = input.dtype;
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    }
-
     if (auto mm = llvm::dyn_cast<AtenMmOp>(op)) {
       return visitAtenMmOp(mm, operands);
     } else if (auto addmm = llvm::dyn_cast<AtenAddmmOp>(op)) {
@@ -492,6 +475,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
     } else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
       return visitAtenConstantPadNdOp(constantPadNdOp, operands);
+    } else if (auto indexTensorOp = dyn_cast<AtenIndexTensorOp>(op)) {
+      return visitAtenIndexTensorOp(indexTensorOp, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -646,6 +631,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult visitAtenNativeLayerNormOp(
       AtenNativeLayerNormOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenIndexTensorOp(AtenIndexTensorOp op,
+                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1848,6 +1836,36 @@ ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
 
   return resultLattice;
 }
+
+// `torch.aten.index.Tensor` return tensors elements selected by the index
+// tensors. Each index tensor in the list corresponds to each dim in the
+// input tensor.
+// Same number of dims but unknown size along each dim. Same dtype as the
+// input.
+ChangeResult TypeAnalyzer::visitAtenIndexTensorOp(
+    AtenIndexTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto indicesList = op.indices();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto listConstruct = indicesList.getDefiningOp<PrimListConstructOp>();
+  if (!listConstruct)
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+
+  auto indices = llvm::to_vector(
+      llvm::map_range(listConstruct.elements(), [&](Value v) -> ValueKnowledge {
+        return getLatticeElement(v).getValue();
+      }));
+
+  knowledge.dtype = input.dtype;
+  // Case: If the input is a 1-d tensor and indices list size is equal
+  // to 1.
+  if (input.sizes.size() == 1 && indices.size() == 1 && indices[0].hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes = indices[0].sizes;
+  }
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
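The transfer function above is deliberately narrow: the result dtype always follows the input, but sizes are only propagated in the 1-d-input, single-index-tensor case, where the result inherits the index tensor's shape. A hypothetical Python sketch of that rule (names invented for illustration):

def index_tensor_result_sizes(input_sizes, index_sizes_list):
    # Mirrors visitAtenIndexTensorOp: sizes are known only when a 1-d
    # input is indexed by exactly one index tensor of known shape.
    if (len(input_sizes) == 1 and len(index_sizes_list) == 1
            and index_sizes_list[0] is not None):
        return list(index_sizes_list[0])  # result shape = index shape
    return None  # unknown sizes; only the dtype is propagated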

test/Dialect/Torch/reduce-op-variants.mlir

Lines changed: 18 additions & 0 deletions
@@ -112,3 +112,21 @@ func @torch.tensor.literal() -> !torch.tensor {
   %0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
   return %0 : !torch.tensor
 }
+
+// CHECK-LABEL:   func @convert_to_value_semantic_tensors_optional_list(
+// CHECK-SAME:        %[[SELF:.*]]: !torch.tensor<[5],f32>,
+// CHECK-SAME:        %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
+// CHECK:           %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDICES]] :
+// CHECK-SAME:        (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
+// CHECK:           %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
+// CHECK:           %[[INDICES_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDICES]] : !torch.vtensor<[2,3],si64>
+// CHECK:           %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<!torch.vtensor<[2,3],si64>>
+// CHECK:           %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<!torch.vtensor<[2,3],si64>> -> !torch.vtensor
+// CHECK:           %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
+// CHECK:           return %[[RET]] : !torch.tensor
+// CHECK:         }
+func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
+  %tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
+  %ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>> -> !torch.tensor
+  return %ret : !torch.tensor
+}
