Commit 0f4b396

Add tcp.gather_nd and rework index.Tensor_hacked_twin to use gather_nd (#101)
1 parent ee58041 commit 0f4b396

File tree

8 files changed, +377 -55 lines changed

include/mlir-tcp/Dialect/IR/TcpOps.td

Lines changed: 24 additions & 0 deletions
@@ -662,6 +662,30 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
   let hasVerifier = 1;
 }
 
+def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {
+
+  let summary = "Gather elements from input based on indices over multiple dimensions";
+
+  let description = [{
+    Gathers elements from a given tensor based on indices that index along multiple dimensions.
+
+    More details regarding this op: docs/gather.md
+  }];
+
+  let arguments = (ins
+    Tcp_Tensor:$input,
+    Tcp_IntTensor:$indices
+  );
+
+  let results = (outs
+    Tcp_Tensor:$out
+  );
+
+  let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
+
+  let hasVerifier = 1;
+}
+
 def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {
 
   let summary = "Extracts a slice of the input tensor";

lib/Conversion/TcpToLinalg/DataMovement.cpp

Lines changed: 97 additions & 0 deletions
@@ -91,6 +91,101 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
   }
 };
 
+/**
+ * tcp.gather_nd is lowered to linalg.generic, which allows us to define every
+ * element in the result tensor using a programmatic expression. The last
+ * dimension of the indices tensor is used to index into the input tensor.
+ *
+ * For example, given an indices tensor of shape 9x4x3x2 and an input tensor
+ * of shape 5x6x7x8, the resulting tensor will be of shape 9x4x3x7x8. The
+ * first three dimensions of the result are used to index into the indices
+ * tensor, and the last dimension of the indices tensor (the dimension of
+ * size 2) is used to index into the input tensor.
+ */
+class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto resultTensorType = getTypeConverter()
+                                ->convertType(op.getOut().getType())
+                                .cast<RankedTensorType>();
+
+    auto inputTensor = adaptor.getInput();
+    auto indicesTensor = adaptor.getIndices();
+    auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
+    auto inputType = cast<RankedTensorType>(inputTensor.getType());
+    int numGatherAxes = indicesType.getShape().back();
+
+    SmallVector<Value> resultDimSizes;
+    for (int i = 0; i < indicesType.getRank() - 1; i++) {
+      resultDimSizes.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
+    }
+    for (int i = numGatherAxes; i < inputType.getRank(); i++) {
+      resultDimSizes.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, inputTensor, i));
+    }
+
+    assert(resultDimSizes.size() == resultTensorType.getRank());
+
+    Value emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
+                                         resultTensorType.getElementType());
+
+    auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
+      SmallVector<Value> valueIndices, gatherIndices;
+      for (int i = 0; i < indicesType.getRank() - 1; i++) {
+        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
+                                             b.getI64IntegerAttr(i));
+        gatherIndices.push_back(idx);
+      }
+      for (int i = 0; i < numGatherAxes; i++) {
+        SmallVector<Value> gi = gatherIndices;
+        auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
+        gi.push_back(gidx);
+        assert(gi.size() == indicesType.getRank());
+        auto idxExtract = b.create<tensor::ExtractOp>(
+            loc, indicesType.getElementType(), indicesTensor, gi);
+        auto idxCast =
+            b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
+        valueIndices.push_back(idxCast);
+      }
+      for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
+           i++) {
+        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
+                                             b.getI64IntegerAttr(i));
+        valueIndices.push_back(idx);
+      }
+      assert(valueIndices.size() == inputType.getRank());
+      auto extract =
+          b.create<tensor::ExtractOp>(loc, resultTensorType.getElementType(),
+                                      inputTensor, valueIndices)
+              .getResult();
+
+      b.create<linalg::YieldOp>(loc, extract);
+    };
+
+    SmallVector<Value> empty;
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.push_back(
+        rewriter.getMultiDimIdentityMap(resultTensorType.getRank()));
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultTensorType.getRank(), utils::IteratorType::parallel);
+
+    auto generic = rewriter.create<linalg::GenericOp>(
+        loc, resultTensorType, empty, emptyTensor, indexingMaps, iteratorTypes,
+        bodyBuilder);
+
+    rewriter.replaceOp(op, generic.getResult(0));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
@@ -100,4 +195,6 @@ void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
 
   target.addIllegalOp<GatherOp>();
   patterns.add<ConvertGatherOp>(typeConverter, context);
+  target.addIllegalOp<GatherNDOp>();
+  patterns.add<ConvertGatherNDOp>(typeConverter, context);
 }
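As a rough hand-written sketch of the lowering described in the comment above (not verbatim pass output; the SSA names and exact attribute spelling are illustrative, and the shapes reuse the 9x4x3x2 indices / 5x6x7x8 input example):

// %in : tensor<5x6x7x8xf32>, %idx : tensor<9x4x3x2xi64>
%empty = tensor.empty() : tensor<9x4x3x7x8xf32>
%gathered = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    outs(%empty : tensor<9x4x3x7x8xf32>) {
  ^bb0(%unused: f32):
    // linalg.index over the leading (batch) dims shared by result and indices.
    %i0 = linalg.index 0 : index
    %i1 = linalg.index 1 : index
    %i2 = linalg.index 2 : index
    // Read the two gather coordinates from the last dim of %idx.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %g0 = tensor.extract %idx[%i0, %i1, %i2, %c0] : tensor<9x4x3x2xi64>
    %g1 = tensor.extract %idx[%i0, %i1, %i2, %c1] : tensor<9x4x3x2xi64>
    %d0 = arith.index_cast %g0 : i64 to index
    %d1 = arith.index_cast %g1 : i64 to index
    // The remaining result dims map directly onto the trailing input dims.
    %i3 = linalg.index 3 : index
    %i4 = linalg.index 4 : index
    %v = tensor.extract %in[%d0, %d1, %i3, %i4] : tensor<5x6x7x8xf32>
    linalg.yield %v : f32
} -> tensor<9x4x3x7x8xf32>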

lib/Conversion/TorchToTcp/DataMovement.cpp

Lines changed: 46 additions & 48 deletions
@@ -278,75 +278,73 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
   }
 };
 
+/**
+ * The index.Tensor_hacked_twin op takes a list of index tensors, which are
+ * broadcast together to the same shape and then fed into a single gather
+ * that selects along the different axes.
+ */
 class ConvertAtenIndexTensorHackedTwin
     : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // ------- Matching the OP -------
     auto self = adaptor.getSelf();
-    auto selfType = cast<RankedTensorType>(self.getType());
     auto indicesList = op.getIndices();
     SmallVector<Value> indices;
     if (!getListConstructElements(indicesList, indices))
       return op.emitError("Failed to match list of indices");
 
-    for (unsigned int i = 0; i < indices.size(); i++) {
-      auto ttype = cast<RankedTensorType>(
-          getTypeConverter()->convertType(indices[i].getType()));
-      if (ttype.getRank() != selfType.getRank() - i) {
-        // Can use tensor.gather instead for this. But will require that there
-        // are some broadcasting to get the shapes to match what is expected
-        return failure("Failed to rewrite Tensor_hacked_twin. Need the "
-                       "element gather for this");
-      }
-      for (int j = 1; j < ttype.getRank(); j++) {
-        if (ttype.getShape()[j] != 1)
-          return failure("Expected the axes >=1 to have size 1");
+    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
+                                     indices);
+
+    if (auto indiciesBroadcasted = torch_to_tcp::broadcastManyToMatchShape(
+            rewriter, op.getLoc(), indices)) {
+      indices = indiciesBroadcasted.value();
+    } else {
+      return failure("failed to broadcast the shapes of the input indicies");
+    }
+
+    for (int i = 0; i < indices.size(); i++) {
+      Value v =
+          torch_to_tcp::broadcastRankInTrailingDims(rewriter, indices[i], 1);
+      if (!cast<RankedTensorType>(v.getType()).getElementType().isInteger(64)) {
+        v = rewriter.createOrFold<tcp::CastOp>(
+            op.getLoc(),
+            RankedTensorType::get(
+                cast<RankedTensorType>(v.getType()).getShape(),
+                rewriter.getI64Type()),
+            v, SignednessAttr::get(op->getContext(), Signedness::Signed),
+            SignednessAttr::get(op->getContext(), Signedness::Signless));
       }
+      indices[i] = v;
     }
 
-    // ------ Rewriting the OP ---------
+    auto indicesType = cast<RankedTensorType>(indices[0].getType());
+    int indicesRank = indicesType.getRank();
+    SmallVector<int64_t> outIndexShape;
+    outIndexShape.insert(outIndexShape.begin(), indicesType.getShape().begin(),
+                         indicesType.getShape().end());
+    outIndexShape.back() = indices.size();
 
-    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
-                                     indices);
+    auto outIndexType =
+        RankedTensorType::get(outIndexShape, indicesType.getElementType());
+    auto indexTensor =
+        rewriter
+            .create<tensor::ConcatOp>(
+                op.getLoc(), outIndexType,
+                rewriter.getI64IntegerAttr(indicesRank - 1), indices)
+            .getResult();
 
-    for (unsigned int i = 0; i < indices.size(); i++) {
-      auto idx = indices[i];
-      auto ttype = cast<RankedTensorType>(idx.getType());
-      auto selfType = cast<RankedTensorType>(self.getType());
-      SmallVector<int64_t> outShape(selfType.getShape());
-      outShape[i] = ttype.getNumElements();
-      auto outType = RankedTensorType::get(
-          outShape, cast<RankedTensorType>(self.getType()).getElementType());
-
-      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
-          rewriter, idx, outShape.size() - ttype.getRank());
-
-      SmallVector<Value> broadcastValues;
-      SmallVector<int64_t> broadcastAxes;
-      for (unsigned int j = 0; j < selfType.getRank(); j++) {
-        if (j != i) {
-          broadcastAxes.push_back(j);
-          broadcastValues.push_back(
-              rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
-        }
-      }
+    auto outType =
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
-      auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
-          op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
-          expandedShape, broadcastValues,
-          rewriter.getI64ArrayAttr(broadcastAxes));
+    auto gatherOp = rewriter.create<tcp::GatherNDOp>(op.getLoc(), outType, self,
+                                                     indexTensor);
 
-      auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
-                                                   broadcastedShape.getResult(),
-                                                   rewriter.getIndexAttr(i));
-      self = gather.getResult();
-    }
+    rewriter.replaceOp(op, gatherOp);
 
-    rewriter.replaceOp(op, self);
     return success();
   }
 };
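To make the reworked flow concrete, here is a hedged sketch with hypothetical shapes and SSA names: two index tensors are broadcast to a common shape (say tensor<4x4xi64>), given a trailing unit dim by broadcastRankInTrailingDims, concatenated along that last dim, and handed to a single tcp.gather_nd on a %self of type tensor<5x6x7xf32>:

%cat = tensor.concat dim(2) %idx0, %idx1
         : (tensor<4x4x1xi64>, tensor<4x4x1xi64>) -> tensor<4x4x2xi64>
%res = tcp.gather_nd %self, %cat : tensor<5x6x7xf32>, tensor<4x4x2xi64> -> tensor<4x4x7xf32>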

lib/Conversion/TorchToTcp/Utils.cpp

Lines changed: 118 additions & 0 deletions
@@ -72,6 +72,32 @@ Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
       input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
 }
 
+// The parameter input is expected to be of RankedTensorType.
+Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
+                                  Value input, int64_t rankIncrease) {
+  if (rankIncrease == 0)
+    return input;
+  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+
+  SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
+  if (inputType.getRank() > 0) {
+    for (int64_t inputAxis = 0; inputAxis < inputType.getRank(); inputAxis++)
+      reassociationMap[inputAxis].push_back(
+          rewriter.getAffineDimExpr(inputAxis));
+    for (int64_t axis = 0; axis < rankIncrease; axis++)
+      reassociationMap.back().push_back(
+          rewriter.getAffineDimExpr(axis + inputType.getRank()));
+  }
+
+  SmallVector<int64_t> resultShape(inputType.getShape());
+  resultShape.insert(resultShape.end(), rankIncrease, 1);
+  auto resultType =
+      inputType.cloneWith(ArrayRef(resultShape), inputType.getElementType());
+
+  return rewriter.create<tensor::ExpandShapeOp>(
+      input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
+}
+
 Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
                               int64_t targetRank, int64_t axisInOutput) {
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
@@ -130,6 +156,98 @@ Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter, Value input,
                                axesAttr);
 }
 
+// The parameter values is expected to contain only RankedTensorType tensors.
+std::optional<SmallVector<Value>>
+broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
+                          ValueRange values) {
+  if (values.size() <= 1) {
+    return values;
+  }
+  SmallVector<Value> ret;
+
+  int64_t maxRank = 0;
+  for (auto v : values) {
+    assert(isa<RankedTensorType>(v.getType()) && "assert 1");
+    auto t = cast<RankedTensorType>(v.getType());
+    if (t.getRank() > maxRank)
+      maxRank = t.getRank();
+  }
+
+  for (auto v : values) {
+    auto type = cast<RankedTensorType>(v.getType());
+    v = broadcastRankInLeadingDims(rewriter, v, maxRank - type.getRank());
+    ret.push_back(v);
+  }
+
+  // Figure out what the size of each result dimension should be.
+  struct DimInfo {
+    Value value;
+    bool found = false;
+    int64_t staticValue = 1;
+  };
+  SmallVector<DimInfo> resultShape(maxRank);
+
+  for (auto v : ret) {
+    auto t = cast<RankedTensorType>(v.getType());
+    auto shape = t.getShape();
+    for (int64_t i = 0; i < maxRank; i++) {
+      if (shape[i] != 1) {
+        // This size is not 1, so it is not simply broadcast away; it
+        // determines (or must match) the result size for this axis.
+        if (resultShape[i].found) {
+          // Multiple inputs have a non-1 size for this axis, so check that
+          // the sizes agree; two different static sizes would make the
+          // broadcast invalid.
+          if (shape[i] != ShapedType::kDynamic &&
+              resultShape[i].staticValue != ShapedType::kDynamic &&
+              resultShape[i].staticValue != shape[i]) {
+            // The broadcast fails because two different static sizes were
+            // seen for this axis.
+            llvm::errs()
+                << "failed with broadcasting, have two different shapes "
+                << shape[i] << " " << resultShape[i].staticValue << "\n";
+            return {};
+          }
+        } else {
+          resultShape[i].found = true;
+          if (shape[i] == ShapedType::kDynamic) {
+            resultShape[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
+            resultShape[i].staticValue = ShapedType::kDynamic;
+          } else {
+            resultShape[i].value = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIndexAttr(shape[i]));
+            resultShape[i].staticValue = shape[i];
+          }
+        }
+      }
+    }
+  }
+
+  // Broadcast each value to the computed result shape.
+  for (int64_t i = 0; i < ret.size(); i++) {
+    auto v = ret[i];
+    auto t = cast<RankedTensorType>(v.getType());
+    SmallVector<int64_t> axes;
+    SmallVector<Value> sizes;
+    SmallVector<int64_t> staticShape;
+    for (int64_t j = 0; j < maxRank; j++) {
+      if (t.getShape()[j] == 1 && resultShape[j].found) {
+        axes.push_back(j);
+        sizes.push_back(resultShape[j].value);
+      }
+      staticShape.push_back(resultShape[j].staticValue);
+    }
+    if (!axes.empty()) {
+      // There is at least one axis to broadcast, so emit a tcp.broadcast.
+      Type resultType = t.cloneWith(staticShape, t.getElementType());
+      ret[i] = rewriter.create<tcp::BroadcastOp>(
+          loc, resultType, ret[i], sizes, rewriter.getI64ArrayAttr(axes));
+    }
+  }
+
+  return ret;
+}
+
 // The parameters input are expected to be of RankedTensorType.
 std::pair<Value, Value>
 broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
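A short shape-level illustration of the two new helpers (hypothetical values and SSA names; depending on the MLIR revision, tensor.expand_shape may also print an output_shape clause):

// broadcastRankInTrailingDims(%idx, 1) appends one unit dim at the end:
%idx_e = tensor.expand_shape %idx [[0], [1, 2]] : tensor<4x4xi64> into tensor<4x4x1xi64>
// broadcastManyToMatchShape({%a, %b}) with %a : tensor<3x1xf32>, %b : tensor<4xf32>:
// %b is first padded in its leading dims to tensor<1x4xf32>, then every size-1
// axis that has a known non-1 counterpart is expanded with tcp.broadcast, so
// both values come back as tensor<3x4xf32>.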
