
Commit 4b6bf0d

Support decomposition of torch.broadcast_tensors (#4253)
Added support for the torch.broadcast_tensors op and its decomposition into a sequence of aten.broadcast_to ops. Closes #4240.
1 parent e65d38e commit 4b6bf0d
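
In effect, the decomposition turns one aten.broadcast_tensors over a tensor list into one aten.broadcast_to per element, all targeting the common broadcast shape. A minimal hand-written sketch in torch-dialect IR (names, shapes, and the exact printed form are illustrative, not taken from this commit):

// Before: broadcast_tensors on a two-element list (shapes assumed [3,1] and [1,4]).
%tensors = torch.prim.ListConstruct %a, %b
    : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[1,4],f32>) -> !torch.list<vtensor>
%out = torch.aten.broadcast_tensors %tensors : !torch.list<vtensor> -> !torch.list<vtensor>

// After: each input is expanded to the common shape [3,4], then repacked as a list.
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%shape = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%a_bc = torch.aten.broadcast_to %a, %shape
    : !torch.vtensor<[3,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%b_bc = torch.aten.broadcast_to %b, %shape
    : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%out = torch.prim.ListConstruct %a_bc, %b_bc
    : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) -> !torch.list<vtensor>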

File tree

14 files changed: +357 additions, −66 deletions


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -11089,6 +11089,7 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [
@@ -12075,6 +12076,29 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
   let hasFolder = 1;
 }
 
+def Torch_AtenBroadcastTensorsOp : Torch_Op<"aten.broadcast_tensors", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::broadcast_tensors : (Tensor[]) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$tensors
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBroadcastTensorsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenBroadcastTensorsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,

include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 2 additions & 2 deletions
@@ -60,10 +60,10 @@ Type getBuiltInTypeForTorchScalar(Type type);
 Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                               Type dtype);
 
-// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
+// Checks whether the inputs are broadcast compatible or not. If
 // yes, then computes the final broadcast shape.
 void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
-                           Value inputA, Value inputB,
+                           ArrayRef<Value> inputs,
                            SmallVector<int64_t> &resultShape,
                            SmallVector<Value> &resultShapeValue);
 
lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 3 additions & 3 deletions
@@ -1065,9 +1065,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         } else {
           SmallVector<int64_t> resultBroadcastShapeInt;
           SmallVector<Value> resultBroadcastShapeValue;
-          Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
-                                       valList[i], resultBroadcastShapeInt,
-                                       resultBroadcastShapeValue);
+          Torch::computeBroadcastShape(
+              rewriter, binder.getLoc(), {curr, valList[i]},
+              resultBroadcastShapeInt, resultBroadcastShapeValue);
           auto baseType = Torch::ValueTensorType::get(
               binder.op->getContext(), resultBroadcastShapeInt,
               resultType.getOptionalDtype());

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 22 additions & 0 deletions
@@ -2838,6 +2838,28 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenAllBoolOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) {
+  auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
+  if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
+    return nullptr;
+  // If all operands are constant true, return true.
+  // If any operand is constant false, return false.
+  bool allConstants = true;
+  for (auto operand : inputConstruct.getOperands()) {
+    bool b;
+    if (!matchPattern(operand, m_TorchConstantBool(&b))) {
+      allConstants = false;
+    } else if (!b) {
+      return getI1IntegerAttr(getContext(), false);
+    }
+  }
+  return allConstants ? getI1IntegerAttr(getContext(), true) : nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenFloatScalarOp
 //===----------------------------------------------------------------------===//
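
Since aten.all.bool now has a folder, a list of known constants can fold away before any conversion runs. A minimal sketch of the folding behavior (hand-written IR, not from this commit):

%true = torch.constant.bool true
%false = torch.constant.bool false
%list = torch.prim.ListConstruct %true, %false : (!torch.bool, !torch.bool) -> !torch.list<bool>
// Folds to `torch.constant.bool false`: a single constant-false operand decides the
// result, even if other operands are non-constant. All-constant-true lists fold to
// true; lists that may be mutated are left untouched.
%r = torch.aten.all.bool %list : !torch.list<bool> -> !torch.bool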

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 60 additions & 0 deletions
@@ -7810,6 +7810,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
+"      %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
+"      torch.prim.If.yield %3 : !torch.list<list<int>>\n"
+"    } else {\n"
+"      %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"      %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"      %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"      %6 = torch.prim.Loop %5, %true, init(%3) {\n"
+"      ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
+"        %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"        %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"        %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
+"      } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
+"      %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
+"      %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"      torch.prim.Loop %8, %true, init() {\n"
+"      ^bb0(%arg1: !torch.int):\n"
+"        %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      torch.prim.If.yield %7 : !torch.list<list<int>>\n"
+"    }\n"
+"    return %2 : !torch.list<list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12556,6 +12587,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
+"    ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
+"      %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
+"      %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"      %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
+"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
+"        %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"        torch.prim.If.yield %8 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %arg2 : !torch.int\n"
+"      }\n"
+"      torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
+"    } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
+"    %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
+"    %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    torch.prim.Loop %3, %true, init() {\n"
+"    ^bb0(%arg1: !torch.int):\n"
+"      %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
+"      %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"      %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"      %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    return %2 : !torch.list<tuple<int, int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
 "    %int7 = torch.constant.int 7\n"
 "    %int6 = torch.constant.int 6\n"
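
Read as an algorithm, the generated shape function left-folds _shape_functions.broadcast across the input shapes and then returns one copy of the common shape per input, while the dtype function pairs the maximum input rank with each input's own dtype. A hand-worked example (values hypothetical, not part of the commit):

// Shape function on two inputs:
//   broadcast([3, 1], [1, 4]) = [3, 4]     (running fold over the list)
//   result = [[3, 4], [3, 4]]              (one copy per input)
// Dtype function on (rank, dtype) tuples (2, f32) and (1, f64):
//   max rank = 2
//   result = [(2, f32), (2, f64)]          (each input keeps its own dtype)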

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 52 additions & 6 deletions
@@ -24,7 +24,6 @@
 #include "llvm/ADT/StringSet.h"
 #include <cstdint>
 #include <set>
-
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
@@ -3415,7 +3414,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
     // calculate common shape for broadcast
     SmallVector<int64_t> broadcastShape;
     SmallVector<Value> broadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
+    computeBroadcastShape(rewriter, loc, {self, other}, broadcastShape,
                           broadcastShapeValue);
 
     Type broadcastType = ValueTensorType::get(
@@ -9109,7 +9108,7 @@ class DecomposeAtenCosineSimilarityOp
     // Broadcast x1 and x2 to the same shape
     SmallVector<int64_t> indexBroadcastShapeInt;
     SmallVector<Value> indexBroadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
+    computeBroadcastShape(rewriter, loc, {x1, x2}, indexBroadcastShapeInt,
                           indexBroadcastShapeValue);
     Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
     Type broadcastType = ValueTensorType::get(
@@ -11482,7 +11481,7 @@ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
     auto resultTy = dyn_cast<BaseTensorType>(op.getType());
     SmallVector<int64_t> broadcastShape;
     SmallVector<Value> broadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+    computeBroadcastShape(rewriter, loc, {input, value}, broadcastShape,
                           broadcastShapeValue);
 
     auto broadcastType = ValueTensorType::get(
@@ -12580,6 +12579,52 @@ class DecomposeAtenRoundDecimalsOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenBroadcastTensorsOp
+    : public OpRewritePattern<AtenBroadcastTensorsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastTensorsOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    SmallVector<Value> tensors;
+    if (!getListConstructElements(op.getTensors(), tensors))
+      return rewriter.notifyMatchFailure(op, "Unable to get tensors");
+    int64_t numTensors = tensors.size();
+
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+
+    computeBroadcastShape(rewriter, loc, tensors, broadcastShape,
+                          broadcastShapeValue);
+
+    auto resType = cast<BaseTensorType>(tensors[0].getType());
+    auto dtype = resType.getDtype();
+    Type broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), dtype);
+
+    Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+
+    SmallVector<Value> broadcastedValues;
+    for (int64_t i = 0; i < numTensors; i++) {
+      auto inputTensor = tensors[i];
+      auto broadcastedVal = rewriter.create<AtenBroadcastToOp>(
+          loc, broadcastType, inputTensor, broadcastShapeTorchList);
+      broadcastedValues.push_back(broadcastedVal);
+    }
+
+    auto broadcastedValuesList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(broadcastType), broadcastedValues);
+
+    rewriter.replaceOp(op, broadcastedValuesList);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
 public:
@@ -12713,8 +12758,8 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
     // calculate common shape for broadcast
     SmallVector<int64_t> broadcastShape;
     SmallVector<Value> broadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
-                          broadcastShapeValue);
+    computeBroadcastShape(rewriter, loc, {finalIndices, index},
+                          broadcastShape, broadcastShapeValue);
     Type broadcastType = ValueTensorType::get(
         context, llvm::ArrayRef(broadcastShape), si64Type);
 
@@ -12974,6 +13019,7 @@ class DecomposeComplexOpsPass
         DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenBroadcastTensorsOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -520,6 +520,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
   target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
   target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
+  target.addIllegalOp<AtenBroadcastTensorsOp>();
   target.addIllegalOp<AtenClampMinOp>();
   target.addIllegalOp<AtenClampMinTensorOp>();
   target.addIllegalOp<AtenClampMaxOp>();
