
Support decomposition of torch.broadcast_tensors #4253

Open · wants to merge 12 commits into base: main
Changes from 2 commits
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11973,6 +11973,29 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
let hasFolder = 1;
}

def Torch_AtenBroadcastTensorsOp : Torch_Op<"aten.broadcast_tensors", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::broadcast_tensors : (Tensor[]) -> (Tensor[])`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBroadcastTensorsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenBroadcastTensorsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
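For context, the generated aten.broadcast_tensors op above mirrors PyTorch's torch.broadcast_tensors, which expands every input to the common broadcast shape without changing any tensor's dtype. A minimal illustration of that behavior (not part of this diff):

import torch

# torch.broadcast_tensors expands every input to the common broadcast shape;
# each output keeps the dtype of its corresponding input.
a = torch.ones(3, 1)                      # shape (3, 1), float32
b = torch.zeros(1, 4, dtype=torch.int64)  # shape (1, 4), int64
out_a, out_b = torch.broadcast_tensors(a, b)
assert out_a.shape == (3, 4) and out_b.shape == (3, 4)
assert out_a.dtype == torch.float32 and out_b.dtype == torch.int64
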
4 changes: 2 additions & 2 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -60,10 +60,10 @@ Type getBuiltInTypeForTorchScalar(Type type);
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype);

// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
// Checks whether the inputs are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<Value> inputs,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue);

6 changes: 3 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1065,9 +1065,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
} else {
SmallVector<int64_t> resultBroadcastShapeInt;
SmallVector<Value> resultBroadcastShapeValue;
Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
valList[i], resultBroadcastShapeInt,
resultBroadcastShapeValue);
Torch::computeBroadcastShape(
rewriter, binder.getLoc(), {curr, valList[i]},
resultBroadcastShapeInt, resultBroadcastShapeValue);
auto baseType = Torch::ValueTensorType::get(
binder.op->getContext(), resultBroadcastShapeInt,
resultType.getOptionalDtype());
60 changes: 60 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7796,6 +7796,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" torch.prim.If.yield %3 : !torch.list<list<int>>\n"
" } else {\n"
" %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %6 = torch.prim.Loop %5, %true, init(%3) {\n"
" ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
" %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %8, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield %7 : !torch.list<list<int>>\n"
" }\n"
" return %2 : !torch.list<list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12407,6 +12438,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %8 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg2 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" return %2 : !torch.list<tuple<int, int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
" %int7 = torch.constant.int 7\n"
" %int6 = torch.constant.int 6\n"
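The two generated functions above (the shape function and the dtype function for aten.broadcast_tensors) encode fairly simple logic. A rough Python sketch of what they compute is shown below, under the assumption that the shape function folds torch.jit._shape_functions.broadcast over the inputs exactly as the IR does; these helpers are illustrative, not the exact source added to the abstract interp library:

from typing import List, Tuple
from torch.jit._shape_functions import broadcast

def broadcast_tensors_shape(tensors: List[List[int]]) -> List[List[int]]:
    # Fold the pairwise broadcast rule over all input shapes, then return
    # one copy of the common shape per input, matching the generated IR.
    if len(tensors) == 0:
        return []
    common = tensors[0]
    for i in range(1, len(tensors)):
        common = broadcast(common, tensors[i])
    return [common for _ in range(len(tensors))]

def broadcast_tensors_dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    # The result rank is the maximum input rank; each output keeps its own
    # dtype, since broadcast_tensors performs no type promotion.
    max_rank = 0
    for rank, _ in tensors_rank_dtype:
        if rank > max_rank:
            max_rank = rank
    return [(max_rank, dtype) for _, dtype in tensors_rank_dtype]
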
52 changes: 49 additions & 3 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -24,7 +24,6 @@
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <set>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@@ -3415,7 +3414,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
// calculate common shape for broadcast
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
computeBroadcastShape(rewriter, loc, {self, other}, broadcastShape,
broadcastShapeValue);

Type broadcastType = ValueTensorType::get(
@@ -8962,7 +8961,7 @@ class DecomposeAtenCosineSimilarityOp
// Broadcast x1 and x2 to the same shape
SmallVector<int64_t> indexBroadcastShapeInt;
SmallVector<Value> indexBroadcastShapeValue;
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
computeBroadcastShape(rewriter, loc, {x1, x2}, indexBroadcastShapeInt,
indexBroadcastShapeValue);
Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
Type broadcastType = ValueTensorType::get(
@@ -12203,6 +12202,52 @@ class DecomposeAtenRoundDecimalsOp
};
} // namespace

namespace {
class DecomposeAtenBroadcastTensorsOp
: public OpRewritePattern<AtenBroadcastTensorsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBroadcastTensorsOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
SmallVector<Value> tensors;
if (!getListConstructElements(op.getTensors(), tensors))
return rewriter.notifyMatchFailure(op, "Unable to get tensors");
int64_t numTensors = tensors.size();

SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;

computeBroadcastShape(rewriter, loc, tensors, broadcastShape,
broadcastShapeValue);

auto resType = cast<BaseTensorType>(tensors[0].getType());
auto dtype = resType.getDtype();
Type broadcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(broadcastShape), dtype);

Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
broadcastShapeValue);

SmallVector<Value> broadcastedValues;
for (int64_t i = 0; i < numTensors; i++) {
auto inputTensor = tensors[i];
auto broadcastedVal = rewriter.create<AtenBroadcastToOp>(
loc, broadcastType, inputTensor, broadcastShapeTorchList);
broadcastedValues.push_back(broadcastedVal);
}

auto broadcastedValuesList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(broadcastType), broadcastedValues);

rewriter.replaceOp(op, broadcastedValuesList);
return success();
}
};
} // namespace
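In eager-mode terms, the pattern above behaves roughly like the sketch below: compute one common broadcast shape (the job computeBroadcastShape does at the IR level, approximated here with torch.broadcast_shapes) and expand every input to it, mirroring the AtenBroadcastToOp emitted per tensor. Note that the pattern derives the dtype of the broadcast result type from the first tensor in the list.

import torch

def decomposed_broadcast_tensors(tensors):
    # One common shape is computed once, then every input is expanded to it.
    common = torch.broadcast_shapes(*(t.shape for t in tensors))
    return [t.broadcast_to(common) for t in tensors]

xs = [torch.ones(3, 1), torch.zeros(4), torch.full((1, 1), 2.0)]
ys = decomposed_broadcast_tensors(xs)
assert all(y.shape == (3, 4) for y in ys)
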

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12403,6 +12448,7 @@ class DecomposeComplexOpsPass
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBroadcastTensorsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -518,6 +518,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
target.addIllegalOp<AtenBroadcastTensorsOp>();
target.addIllegalOp<AtenClampMinOp>();
target.addIllegalOp<AtenClampMinTensorOp>();
target.addIllegalOp<AtenClampMaxOp>();
129 changes: 84 additions & 45 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -479,78 +479,117 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
return unsqueezed;
}

// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
// Checks whether the inputs are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<Value> inputs,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue) {
SmallVector<int64_t> shapeA{
cast<BaseTensorType>(inputA.getType()).getSizes()};
SmallVector<int64_t> shapeB{
cast<BaseTensorType>(inputB.getType()).getSizes()};
unsigned rankA = shapeA.size();
unsigned rankB = shapeB.size();
unsigned minRank = rankA > rankB ? rankB : rankA;

SmallVector<SmallVector<int64_t>> shapes;
SmallVector<unsigned> ranks;

for (auto input : inputs) {
SmallVector<int64_t> shape{
cast<BaseTensorType>(input.getType()).getSizes()};
shapes.push_back(shape);
ranks.push_back(shape.size());
}

unsigned maxRank = *std::max_element(ranks.begin(), ranks.end());

// Check whether the shapes of the tensors are broadcastable or not.
// Two tensors are “broadcastable” if the following rules hold:
// 1.) Each tensor has at least one dimension.
// 2.) When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
for (unsigned i = 0; i < maxRank; i++) {

SmallVector<Value> sizeInputs;
for (auto [idx, input] : llvm::enumerate(inputs)) {
int sizeDimIdx = ranks[idx] - i - 1;
if (sizeDimIdx >= 0) {
auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sizeDimIdx));
sizeInputs.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
}
}

Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cmpSizeAEqualsSizeB =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, sizeInputB);
Value cmpSizeAEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, torchCstOne);
Value cmpSizeBEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputB, torchCstOne);
SmallVector<Value> predicates;
for (auto sizeVal : sizeInputs) {
Value cmpSizeEquals =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, sizeInputs.front());
predicates.push_back(cmpSizeEquals);
Value cmpSizeEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, torchCstOne);
predicates.push_back(cmpSizeEqualsOne);
}

Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
cmpSizeBEqualsOne});
loc, Torch::ListType::get(predicates.front().getType()), predicates);
Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
rewriter.create<Torch::RuntimeAssertOp>(
loc, cmp, "tensors are not broadcast compatible");
}

// If we reach here then it means both the shapes are broadcast compatible.
resultShape = rankA >= rankB ? shapeA : shapeB;
Value shapeTensor = rankA >= rankB ? inputA : inputB;
auto maxRankIdx =
std::max_element(ranks.begin(), ranks.end()) - ranks.begin();
resultShape = shapes[maxRankIdx];
Value shapeTensor = inputs[maxRankIdx];

for (unsigned i = 0; i < resultShape.size(); i++) {
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
resultShapeValue.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
}

unsigned resultRank = resultShape.size();
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
resultShapeValue[resultRank - i - 1] =
rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
if (shapeA[rankA - i - 1] == kUnknownSize ||
shapeB[rankB - i - 1] == kUnknownSize) {
for (unsigned i = 0; i < maxRank; i++) {

SmallVector<Value> sizeInputs;
for (auto [idx, input] : llvm::enumerate(inputs)) {
int sizeDimIdx = ranks[idx] - i - 1;
if (sizeDimIdx >= 0) {
auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sizeDimIdx));
sizeInputs.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
}
}

// Compute shape value of broadcast result,
// which is the maximum of dimension sizes across all inputs
Value maxShapeVal = sizeInputs.front();
for (auto sizeInput : sizeInputs) {
maxShapeVal = rewriter.create<PrimMaxIntOp>(loc, maxShapeVal, sizeInput);
}
resultShapeValue[resultRank - i - 1] = maxShapeVal;

// Compute result shape if all input shapes are known
bool unknownSize = false;
for (auto [idx, shape] : llvm::enumerate(shapes)) {
if (ranks[idx] - i - 1 < shape.size() &&
shape[ranks[idx] - i - 1] == kUnknownSize) {
Collaborator: This is a bit confusing. 0 <= i < maxRank, so it seems redundant to check ranks[idx] - i - 1 < ranks[idx]. Perhaps the check should be ranks[idx] - i - 1 >= 0?

Contributor Author: @zjgarvey, this check ensures that out-of-bounds access to shape is not performed in case the input tensors have different ranks. The e2e test BroadcastTensorsModuleList_multiple_ranks covers this case. Let me know if you have a suggestion on making this check easier to read.

Collaborator: Maybe "confusing" isn't the correct word. I'm rather certain the check is wrong. E.g. for ranks = {1, 2}, maxRank = 2, when idx = 0 and i = 1, then ranks[idx] - i - 1 = 1 - 1 - 1 = -1, which you are using to access shape for the tensor at idx = 0. Whereas ranks[idx] - i - 1 < ranks[idx] is automatic from the fact that i >= 0.

unknownSize = true;
}
}

if (unknownSize) {
resultShape[resultRank - i - 1] = kUnknownSize;
} else {
resultShape[resultRank - i - 1] =
std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);

int64_t maxShape = 1;
for (auto [idx, shape] : llvm::enumerate(shapes)) {
if (ranks[idx] - i - 1 < shape.size()) {
maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]);
}
}
resultShape[resultRank - i - 1] = maxShape;
}
}
}
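
As a sanity check on the per-dimension logic discussed in the review thread above, here is a small Python sketch of the static-shape part of the computation, written with the bounds check the reviewer suggests (the dimension index must be >= 0, i.e. only tensors that actually have that trailing dimension participate). kUnknownSize is modeled as -1 and the names are illustrative:

K_UNKNOWN = -1  # stand-in for Torch::kUnknownSize

def broadcast_static_shapes(shapes):
    max_rank = max(len(s) for s in shapes)
    result = [1] * max_rank
    for i in range(max_rank):
        # Only tensors that have this trailing dimension contribute to it.
        dims = [s[len(s) - i - 1] for s in shapes if len(s) - i - 1 >= 0]
        if any(d == K_UNKNOWN for d in dims):
            result[max_rank - i - 1] = K_UNKNOWN
        else:
            result[max_rank - i - 1] = max(dims)
    return result

assert broadcast_static_shapes([[3, 1], [4]]) == [3, 4]
assert broadcast_static_shapes([[K_UNKNOWN, 1], [4]]) == [K_UNKNOWN, 4]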