Skip to content

[MLIR][TORCH] Add op verifier for aten.index_put op #4184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,14 +803,42 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
}

// Statically verifies that two shapes can be broadcast against each other.
//
// The shapes are aligned at their trailing dimension and compared pairwise,
// walking towards the leading dimension. A pair of sizes is compatible when
// the sizes are equal or when either of them is 1. Dimensions that exist in
// only one of the shapes (the extra leading dimensions of the higher-rank
// shape) are never compared, so they cannot cause a failure.
//
// Any pair in which either size is `Torch::kUnknownSize` is skipped: only
// statically known sizes are validated here.
//
// Note that no minimum rank is enforced; if either shape is empty the loop
// body never runs and the result is `success()`.
//
// Returns `success()` if every compared pair is compatible and `failure()` as
// soon as an incompatible pair is found.
static LogicalResult
areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
                                 ArrayRef<int64_t> shapeB) {
  const unsigned numDimsA = shapeA.size();
  const unsigned numDimsB = shapeB.size();
  // Only the trailing dimensions present in both shapes need to be compared.
  const unsigned numSharedDims = std::min(numDimsA, numDimsB);

  // `offset` counts dimensions from the back: 1 is the trailing dimension.
  for (unsigned offset = 1; offset <= numSharedDims; ++offset) {
    const int64_t sizeA = shapeA[numDimsA - offset];
    const int64_t sizeB = shapeB[numDimsB - offset];

    // A dynamic size cannot be validated statically; skip this pair.
    const bool hasDynamicSize =
        sizeA == Torch::kUnknownSize || sizeB == Torch::kUnknownSize;
    if (hasDynamicSize)
      continue;

    // Static sizes must match exactly unless one of them is 1.
    if (sizeA != sizeB && sizeA != 1 && sizeB != 1)
      return failure();
  }

  return success();
}

// Broadcast the `values` tensor to the slice size created by the list of index
// tensors.
static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
llvm::ArrayRef<Value> indices,
OpBuilder b) {
static LogicalResult broadcastValuesToSliceSize(Location loc, Value input,
Value values,
llvm::ArrayRef<Value> indices,
OpBuilder b, Value &result) {
auto inputType = cast<ValueTensorType>(input.getType());
ArrayRef<int64_t> inputStaticShape = inputType.getSizes();
auto valuesType = cast<ValueTensorType>(values.getType());
ArrayRef<int64_t> valuesStaticShape = valuesType.getSizes();

// In the case where the input rank is greater than the number of index
// tensors, the remaining dimensions of the input are indexed in their
Expand All @@ -823,12 +851,20 @@ static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
resultStaticShape.push_back(inputStaticShape[i]);
}

// Check whether the values tensor is broadcast compatible with the indexing
// result shape. Here, we only check the static dimensions; the dynamic ones
// will be caught by the downstream lowering.
if (failed(areStaticallyBroadcastCompatible(valuesStaticShape,
resultStaticShape)))
return failure();

auto resultType = b.getType<Torch::ValueTensorType>(
resultStaticShape, valuesType.getOptionalDtype());
Value broadcastShapeList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(b.getType<Torch::IntType>()), resultShape);
return b.create<AtenBroadcastToOp>(loc, resultType, values,
broadcastShapeList);
result =
b.create<AtenBroadcastToOp>(loc, resultType, values, broadcastShapeList);
return success();
}

class ConvertAtenIndexPutHackedTwinOp
Expand Down Expand Up @@ -878,8 +914,12 @@ class ConvertAtenIndexPutHackedTwinOp
if (optionalIndicesCount == 0)
return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");

values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList,
rewriter);
if (failed(broadcastValuesToSliceSize(loc, input, values,
optionalIndicesList, rewriter,
/*result=*/values)))
return rewriter.notifyMatchFailure(
op, "values tensor cannot be broadcast to indexing result shape.");

// Filter to available indices and get the indicesMap:
SmallVector<Value> indicesList;
SmallVector<int64_t> indicesMap;
Expand Down
Loading