1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11041,6 +11041,7 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2838,6 +2838,24 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAllBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) {
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
return nullptr;
// Fold to true only when every element of the list is a constant true;
// otherwise bail out and leave the op untouched.
for (auto operand : inputConstruct.getOperands()) {
bool b = true;
if (!matchPattern(operand, m_TorchConstantBool(&b)) || !b) {
return nullptr;
}
}
return getI1IntegerAttr(getContext(), true);
}

//===----------------------------------------------------------------------===//
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//
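For reference, a minimal sketch of how this folder behaves (SSA names here are illustrative, not taken from the patch): it fires only when every list element is a constant true, and returns nullptr, leaving the op alone, when any element is false or not a compile-time constant.

%true = torch.constant.bool true
%all_true = torch.prim.ListConstruct %true, %true : (!torch.bool, !torch.bool) -> !torch.list<bool>
%0 = torch.aten.all.bool %all_true : !torch.list<bool> -> !torch.bool  // folds to %true
%mixed = torch.prim.ListConstruct %true, %cond : (!torch.bool, !torch.bool) -> !torch.list<bool>
%1 = torch.aten.all.bool %mixed : !torch.list<bool> -> !torch.bool     // %cond is not a constant, so no fold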
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11328,7 +11328,7 @@ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
auto resultTy = dyn_cast<BaseTensorType>(op.getType());
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
computeBroadcastShape(rewriter, loc, {input, value}, broadcastShape,
broadcastShapeValue);

auto broadcastType = ValueTensorType::get(
53 changes: 25 additions & 28 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -488,6 +488,7 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,

SmallVector<SmallVector<int64_t>> shapes;
SmallVector<unsigned> ranks;
SmallVector<Value> maxShapeValues;

for (auto input : inputs) {
SmallVector<int64_t> shape{
@@ -496,6 +497,8 @@
ranks.push_back(shape.size());
}

Value torchCstOne =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
unsigned maxRank = *std::max_element(ranks.begin(), ranks.end());

// Check whether the shapes of the tensors are broadcastable or not.
@@ -517,23 +520,34 @@
}
}

Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
// Compute the dimension size of the broadcast result, which is the
// maximum of the corresponding dimension sizes across all inputs.
Value maxShapeVal = sizeInputs.front();
for (auto sizeInput : sizeInputs) {
maxShapeVal = rewriter.create<PrimMaxIntOp>(loc, maxShapeVal, sizeInput);
}
maxShapeValues.push_back(maxShapeVal);

SmallVector<Value> predicates;
for (auto sizeVal : sizeInputs) {
Value cmpSizeEquals =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, sizeInputs.front());
predicates.push_back(cmpSizeEquals);
rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, maxShapeVal);
Value cmpSizeEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, torchCstOne);
predicates.push_back(cmpSizeEqualsOne);
Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(cmpSizeEquals.getType()),
SmallVector<Value>{cmpSizeEquals, cmpSizeEqualsOne});
Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
predicates.push_back(cmp);
}

Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(predicates.front().getType()), predicates);
Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
rewriter.create<Torch::RuntimeAssertOp>(
loc, cmp, "tensors are not broadcast compatible");
if (!predicates.empty()) {
Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(predicates.front().getType()), predicates);
Value cmp = rewriter.create<Torch::AtenAllBoolOp>(loc, anyBoolOpList);
rewriter.create<Torch::RuntimeAssertOp>(
loc, cmp, "tensors are not broadcast compatible");
}
}

// If we reach here then it means both the shapes are broadcast compatible.
@@ -551,24 +565,7 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
unsigned resultRank = resultShape.size();
for (unsigned i = 0; i < maxRank; i++) {

SmallVector<Value> sizeInputs;
for (auto [idx, input] : llvm::enumerate(inputs)) {
int sizeDimIdx = ranks[idx] - i - 1;
if (sizeDimIdx >= 0) {
auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sizeDimIdx));
sizeInputs.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
}
}

// Compute shape value of broadcast result,
// which is the maximum of dimension sizes across all inputs
Value maxShapeVal = sizeInputs.front();
for (auto sizeInput : sizeInputs) {
maxShapeVal = rewriter.create<PrimMaxIntOp>(loc, maxShapeVal, sizeInput);
}
resultShapeValue[resultRank - i - 1] = maxShapeVal;
resultShapeValue[resultRank - i - 1] = maxShapeValues[i];

// Compute result shape if all input shapes are known
bool unknownSize = false;
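To make the refactor concrete, here is a rough sketch of the IR this helper is now expected to emit for a single broadcast dimension, assuming two dynamically shaped inputs (SSA names, the %dim index, and tensor types are illustrative, not taken from the patch): the per-dimension maximum is computed once and reused for the result shape, each input contributes a "size == max or size == 1" predicate via aten.any.bool, and all predicates are combined with aten.all.bool before a single runtime assert.

%int1 = torch.constant.int 1
%d0 = torch.aten.size.int %lhs, %dim : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%d1 = torch.aten.size.int %rhs, %dim : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%max = torch.prim.max.int %d0, %d1 : !torch.int, !torch.int -> !torch.int
%eq_max0 = torch.aten.eq.int %d0, %max : !torch.int, !torch.int -> !torch.bool
%eq_one0 = torch.aten.eq.int %d0, %int1 : !torch.int, !torch.int -> !torch.bool
%pair0 = torch.prim.ListConstruct %eq_max0, %eq_one0 : (!torch.bool, !torch.bool) -> !torch.list<bool>
%ok0 = torch.aten.any.bool %pair0 : !torch.list<bool> -> !torch.bool
// ...the same three ops for %d1 produce %ok1...
%preds = torch.prim.ListConstruct %ok0, %ok1 : (!torch.bool, !torch.bool) -> !torch.list<bool>
%all_ok = torch.aten.all.bool %preds : !torch.list<bool> -> !torch.bool
torch.runtime.assert %all_ok, "tensors are not broadcast compatible"

When every size is static, the eq/any/all chain folds to a constant true (via the new aten.all.bool folder), leaving only the bare runtime asserts, which is what the updated decompose-complex-ops.mlir CHECK lines below expect.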
@@ -847,7 +847,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::isneginf : (Tensor) -> (Tensor)")
emit("aten::isposinf : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::all.bool : (bool[]) -> (bool)")
emit("aten::all.bool : (bool[]) -> (bool)", has_folder=True)
emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::any : (Tensor) -> (Tensor)")
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
10 changes: 10 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2685,6 +2685,16 @@ func.func @torch.aten.any.bool$fold() -> !torch.bool {
return %0 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.all.bool$fold() -> !torch.bool {
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[CST_TRUE]] : !torch.bool
func.func @torch.aten.all.bool$fold() -> !torch.bool {
%true = torch.constant.bool true
%input = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%0 = torch.aten.all.bool %input : !torch.list<bool> -> !torch.bool
return %0 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.floor$canonicalize
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?,?],si64>
25 changes: 13 additions & 12 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -849,18 +849,19 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf

// -----

// CHECK-LABEL: func.func @torch.aten.broadcast_tensors(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: %[[VAR1:.*]] = torch.constant.int 2
// CHECK: %[[VAR2:.*]] = torch.constant.int 3
// CHECK: %[[VAR3:.*]] = torch.constant.bool true
// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible"
// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible"
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[VAR1]], %[[VAR2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[ARG0:.*]], %[[VAR4]] : !torch.vtensor<[1,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[ARG1:.*]], %[[VAR4]] : !torch.vtensor<[2,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[VAR7:.*]] = torch.prim.ListConstruct %[[VAR5]], %[[VAR6]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: return %[[VAR7]] : !torch.list<vtensor<[2,3],f32>>
// CHECK-LABEL: func.func @torch.aten.broadcast_tensors
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[B0:.*]] = torch.aten.broadcast_to %[[ARG0]], %[[SHAPE]] : !torch.vtensor<[1,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[B1:.*]] = torch.aten.broadcast_to %[[ARG1]], %[[SHAPE]] : !torch.vtensor<[2,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: return %[[LIST]] : !torch.list<vtensor<[2,3],f32>>
func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor<[2,3], f32>> {
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor>
%1 = torch.aten.broadcast_tensors %0 : !torch.list<vtensor> -> !torch.list<vtensor<[2,3],f32>>
Expand Down