
Commit 2f129aa

Add converter for index.Tensor_hacked_twin (#98)
* Add converter for index.Tensor_hacked_twin -> tcp.gather & tcp.broadcast
* Remove the `tcp.custom_op` variants of `Tensor_hacked_twin`
* The Tcp variants do not yet have full coverage of the PyTorch op, but we should seek to expand the coverage of our converters
* Fix tcp.const when its value is a dense resource of ints, so the value is cast to the converted type
* Add verifiers for `tcp.gather` and `tcp.const` to ensure each is used correctly
1 parent 3fc3290 commit 2f129aa
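For context: `aten.index.Tensor_hacked_twin` is the advanced-indexing variant where every entry in the index list is a tensor, and the index tensors broadcast against each other. A minimal PyTorch sketch of the semantics the new converter targets (the tensors here are illustrative, not taken from the commit):

```python
import torch

x = torch.rand(4, 10)

# Index tensors broadcast against each other: x[i1, i2] selects
# x[i1[a, b], i2[a, b]] at every broadcasted position (a, b).
i1 = torch.tensor([[0], [1], [2], [3]])  # shape (4, 1)
i2 = torch.tensor([2, 5, 7])             # shape (3,)

out = x[i1, i2]  # traces to aten.index.Tensor_hacked_twin
assert out.shape == (4, 3)
assert out[1, 2] == x[1, 7]  # row index from i1, column index from i2
```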

File tree

12 files changed: +187 -75 lines

docker/Dockerfile

Lines changed: 4 additions & 2 deletions

@@ -23,7 +23,8 @@ RUN apt-get update && \
     clang \
     clang-format \
     gdb \
-    black
+    black \
+    sudo
 
 # Install bazel
 ARG ARCH="x86_64"
@@ -42,7 +43,8 @@ WORKDIR /opt/src/mlir-tcp
 RUN groupadd -o -g ${GID} ${GROUP} && \
     useradd -u ${UID} -g ${GROUP} -ms /bin/bash ${USER} && \
     usermod -aG sudo ${USER} && \
-    chown -R ${USER}:${GROUP} /opt/src/mlir-tcp
+    chown -R ${USER}:${GROUP} /opt/src/mlir-tcp && \
+    echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
 
 # Switch to user
 USER ${USER}

include/mlir-tcp/Dialect/IR/TcpOps.td

Lines changed: 3 additions & 0 deletions

@@ -218,6 +218,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
   let assemblyFormat = "attr-dict `:` type($out)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
@@ -657,6 +658,8 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
   );
 
   let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
+
+  let hasVerifier = 1;
 }
 
 def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {

lib/Conversion/TorchToTcp/DataMovement.cpp

Lines changed: 76 additions & 0 deletions

@@ -278,6 +278,79 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
   }
 };
 
+class ConvertAtenIndexTensorHackedTwin
+    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // ------- Matching the OP -------
+    auto self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+    auto indicesList = op.getIndices();
+    SmallVector<Value> indices;
+    if (!getListConstructElements(indicesList, indices))
+      return op.emitError("Failed to match list of indices");
+
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto ttype = cast<RankedTensorType>(
+          getTypeConverter()->convertType(indices[i].getType()));
+      if (ttype.getRank() != selfType.getRank() - i) {
+        // Can use tensor.gather instead for this. But will require that there
+        // are some broadcasting to get the shapes to match what is expected
+        return failure("Failed to rewrite Tensor_hacked_twin. Need the "
+                       "element gather for this");
+      }
+      for (int j = 1; j < ttype.getRank(); j++) {
+        if (ttype.getShape()[j] != 1)
+          return failure("Expected the axes >=1 to have size 1");
+      }
+    }
+
+    // ------ Rewriting the OP ---------
+
+    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
+                                     indices);
+
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto idx = indices[i];
+      auto ttype = cast<RankedTensorType>(idx.getType());
+      auto selfType = cast<RankedTensorType>(self.getType());
+      SmallVector<int64_t> outShape(selfType.getShape());
+      outShape[i] = ttype.getNumElements();
+      auto outType = RankedTensorType::get(
+          outShape, cast<RankedTensorType>(self.getType()).getElementType());
+
+      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
+          rewriter, idx, outShape.size() - ttype.getRank());
+
+      SmallVector<Value> broadcastValues;
+      SmallVector<int64_t> broadcastAxes;
+      for (unsigned int j = 0; j < selfType.getRank(); j++) {
+        if (j != i) {
+          broadcastAxes.push_back(j);
+          broadcastValues.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
+        }
+      }
+
+      auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
+          op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
+          expandedShape, broadcastValues,
+          rewriter.getI64ArrayAttr(broadcastAxes));
+
+      auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
+                                                   broadcastedShape.getResult(),
+                                                   rewriter.getIndexAttr(i));
+      self = gather.getResult();
+    }
+
+    rewriter.replaceOp(op, self);
+    return success();
+  }
+};
+
 } // namespace
 
 void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -294,4 +367,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
                                                    AtenIndexSelectOp>(
       typeConverter, patterns, target, convertTorchOpsSet);
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<
+      ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>(
+      typeConverter, patterns, target, convertTorchOpsSet);
 }
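The rewrite loop above peels off one index tensor at a time: the index for dim i is broadcast up to the intermediate result shape (tcp.broadcast over every axis except i) and then gathered along dim i, with the gather result feeding the next iteration. A hedged Python sketch of the same decomposition, using torch.gather as a stand-in for tcp.gather (the helper name and test tensors are illustrative, not from the commit):

```python
import torch

def index_via_chained_gathers(self_t, indices):
    # One broadcast + gather per index tensor, mirroring the conversion loop.
    for i, idx in enumerate(indices):
        out_shape = list(self_t.shape)
        out_shape[i] = idx.numel()
        # Place the index values along dim i, then expand across the other
        # axes -- the analogue of broadcastRankInLeadingDims + tcp.broadcast.
        view = [1] * len(out_shape)
        view[i] = idx.numel()
        bcast = idx.reshape(view).expand(out_shape)
        # torch.gather along dim i plays the role of tcp.gather here.
        self_t = torch.gather(self_t, i, bcast)
    return self_t

x = torch.rand(1, 20, 30)
i0 = torch.arange(1).reshape(1, 1, 1)  # selects all of dim 0
i1 = torch.randint(0, 20, (5, 1))      # selects along dim 1
i2 = torch.randint(0, 30, (20,))       # selects along dim 2
assert torch.equal(index_via_chained_gathers(x, [i0, i1, i2]), x[i0, i1, i2])
```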

lib/Conversion/TorchToTcp/Misc.cpp

Lines changed: 11 additions & 0 deletions

@@ -16,6 +16,7 @@
 #include "Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -162,6 +163,16 @@ class ConvertValueTensorLiteralOp
       rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
       return success();
     }
+    if (auto elements =
+            dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
+      if (resultType.getElementType().isInteger() &&
+          resultType != adaptor.getValue().getType()) {
+        auto attr =
+            DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
+        rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
+        return success();
+      }
+    }
 
     rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
                                               adaptor.getValue());
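The dense-resource branch exists because torch-mlir integer literals arrive with signed types (e.g. si64) while the converted Tcp result type is signless (i64); for a resource-backed literal the raw blob handle can simply be rewrapped under the converted type without copying the payload. A loose NumPy analogy of that rewrap (illustrative only, not the MLIR API):

```python
import numpy as np

# Same raw bytes, rewrapped under a different integer interpretation:
blob = np.array([1, 2, 3], dtype=np.int64).tobytes()
reinterpreted = np.frombuffer(blob, dtype=np.uint64)  # type changes, bytes do not
assert reinterpreted.tobytes() == blob
```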

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

Lines changed: 1 addition & 28 deletions

@@ -29,6 +29,7 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
 namespace {
+
 class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -46,33 +47,6 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
     return helper.replace();
   }
 };
-
-class ConvertAtenIndexTensorHackedTwinOp
-    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
-                                                            getTypeConverter()};
-
-    Value input = adaptor.getSelf();
-    auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
-    // Check input is a tensor type.
-    if (!inputTensorType)
-      return rewriter.notifyMatchFailure(
-          op, "Only tensor types input are currently supported");
-
-    helper.addOperand("self", input);
-    helper.addAsMultipleTensorOperands("index_", op.getIndices());
-
-    return helper.replace();
-  }
-};
-
 class ConvertAten_IndexPutImplOp
     : public OpConversionPattern<Aten_IndexPutImplOp> {
 public:
@@ -381,7 +355,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>(   \
       typeConverter, patterns, target, convertTorchOpsSet)
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
-  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(

lib/Conversion/TorchToTcp/Utils.cpp

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,8 @@ Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
 // The parameter input is expected to be of RankedTensorType.
 Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
                                  Value input, int64_t rankIncrease) {
+  if (rankIncrease == 0)
+    return input;
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
 
   SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());

lib/Dialect/IR/TcpOps.cpp

Lines changed: 42 additions & 0 deletions

@@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {
 
 OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); }
 
+LogicalResult ConstOp::verify() {
+  if (getValueAttr().getType() != getType())
+    return emitOpError("can not be used to cast types");
+  return success();
+}
+
 LogicalResult CastOp::verify() {
   auto inputType = getIn().getType().cast<RankedTensorType>();
   auto outputType = getOut().getType().cast<RankedTensorType>();
@@ -170,6 +176,42 @@ LogicalResult CastOp::verify() {
   return success();
 }
 
+LogicalResult GatherOp::verify() {
+  auto inputTensor = cast<RankedTensorType>(getInput().getType());
+  auto indicesTensor = cast<RankedTensorType>(getIndices().getType());
+  int64_t gatherDim = getDimAttr().getValue().getSExtValue();
+
+  if (inputTensor.getRank() != indicesTensor.getRank())
+    return emitOpError(
+        "requires that the input tensor and indices are the same rank");
+
+  for (int i = 0; i < inputTensor.getRank(); i++) {
+    if (inputTensor.getShape()[i] < indicesTensor.getShape()[i] &&
+        !(inputTensor.getShape()[i] == ShapedType::kDynamic ||
+          indicesTensor.getShape()[i] == ShapedType::kDynamic ||
+          i == gatherDim)) {
+      std::stringstream ss;
+      ss << "indicies index " << i
+         << " expected to be less than or equal to input "
+         << " (" << indicesTensor.getShape()[i]
+         << " <= " << inputTensor.getShape()[i] << ")";
+      return emitOpError(ss.str());
+    }
+  }
+
+  if (getResult().getType().getShape() != indicesTensor.getShape()) {
+    return emitOpError(
+        "Expect the shape of the indicies to match the output shape");
+  }
+
+  if (getResult().getType().getElementType() != inputTensor.getElementType()) {
+    return emitOpError(
+        "Expect the element type of the return to match the input");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BindSymbolicShapeOp
 //===----------------------------------------------------------------------===//
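Taken together, the new verifier pins down `tcp.gather` as an elementwise gather in the style of `torch.gather`: input and indices have the same rank, the result takes the indices' shape and the input's element type, and on every non-gather axis the indices extent must fit within the input (dynamic sizes exempt). A hedged reference-semantics sketch in Python (the function and tensors are illustrative):

```python
import torch

def tcp_gather_ref(inp, indices, dim):
    # Checks mirrored from GatherOp::verify:
    assert inp.dim() == indices.dim(), "input and indices must have the same rank"
    for d in range(inp.dim()):
        if d != dim:
            assert indices.size(d) <= inp.size(d), f"indices too large on dim {d}"
    # Result: indices' shape, input's element type.
    return torch.gather(inp, dim, indices)

x = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([[0, 3], [2, 1], [1, 0]])  # shape (3, 2), gather along dim 1
out = tcp_gather_ref(x, idx, dim=1)
assert out.shape == idx.shape
assert out[1, 0] == x[1, 2]                   # out[i, j] = x[i, idx[i, j]]
```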

test/AotCompile/BUILD

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ AOT_TEST_SUITE = [
     ("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
     ("gather_elements", False),
     ("gather_slices", False),
+    ("index_hacked_twin", False),
 ]
 
 py_library(

test/AotCompile/model_loader_lib.py

Lines changed: 16 additions & 0 deletions

@@ -590,3 +590,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return TorchLoaderOutput(
         model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
+
+
+def index_hacked_twin_loader() -> TorchLoaderOutput:
+    class Model(torch.nn.Module):
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            # not using dynamic dim currently as the i1 tensor would ideally
+            # be generated conditioned on the shape
+            i1 = torch.tensor([[0], [1], [2], [3]])
+            return x[i1, [2, 5, 7]]
+
+    x = torch.rand(4, 10)
+
+    return TorchLoaderOutput(
+        model=Model(),
+        inputs=(x,),
+    )

test/Conversion/TorchToTcp/data_movement.mlir

Lines changed: 26 additions & 0 deletions

@@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
   %0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
   return %0 : !torch.vtensor<[4,2],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
+// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
+// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
+// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
+// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
+// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
+// CHECK: return %[[RET]]
+func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
+  // There is a strange pattern generated when selecting along a single axis: Tensor_hacked_twin
+  // is used to select along all axes, with an arange index selecting all of the elements
+  // along the axes that are not otherwise indexed.
+  %none = torch.constant.none
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4 // this is a dtype on arange....
+  %int-1 = torch.constant.int -1
+  %arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
+  %arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  %arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>
+
+  %l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list<vtensor>
+  %ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>
+  return %ret : !torch.vtensor<[1,5,20],f32>
+}
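On the arange pattern this test mentions: when only some axes are indexed, the frontend appears to materialize an arange (with trailing unit dims) for each axis that keeps all of its elements, so that Tensor_hacked_twin indexes every axis. A hedged PyTorch illustration of why that is equivalent (tensor names are illustrative):

```python
import torch

x = torch.rand(1, 20, 30)
sel1 = torch.randint(0, 20, (5, 1))
sel2 = torch.randint(0, 30, (20,))
# An arange selecting "everything" along dim 0, with unit dims appended --
# the same role %arange2 plays in the IR above.
keep0 = torch.arange(1).reshape(1, 1, 1)
assert torch.equal(x[keep0, sel1, sel2], x[:, sel1, sel2])
```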
