Skip to content

Commit 06d1789

Browse files
authored
[Torch Dialect] Allow simplification of shape calculations of aten.tile, col2im, aten.stft (#3785)
- Add `aten.mul.left_t` (+ canonicalizer) to allow simplification of aten.tile.
- Change the syntax of the col2im shape computation so that an existing canonicalization pattern (for `aten.add.t`) can be used for its simplification.
- Add `aten.eq.bool` (+ folder) to allow simplification of aten.stft.
1 parent 1201bab commit 06d1789

File tree

6 files changed

+162
-12
lines changed

6 files changed

+162
-12
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16278,6 +16278,31 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
1627816278
}];
1627916279
}
1628016280

16281+
def Torch_AtenEqBoolOp : Torch_Op<"aten.eq.bool", [
16282+
AllowsTypeRefinement,
16283+
HasValueSemantics,
16284+
ReadOnly
16285+
]> {
16286+
let summary = "Generated op for `aten::eq.bool : (bool, bool) -> (bool)`";
16287+
let arguments = (ins
16288+
Torch_BoolType:$a,
16289+
Torch_BoolType:$b
16290+
);
16291+
let results = (outs
16292+
Torch_BoolType:$result
16293+
);
16294+
let hasCustomAssemblyFormat = 1;
16295+
let extraClassDefinition = [{
16296+
ParseResult AtenEqBoolOp::parse(OpAsmParser &parser, OperationState &result) {
16297+
return parseDefaultTorchOp(parser, result, 2, 1);
16298+
}
16299+
void AtenEqBoolOp::print(OpAsmPrinter &printer) {
16300+
printDefaultTorchOp(printer, *this, 2, 1);
16301+
}
16302+
}];
16303+
let hasFolder = 1;
16304+
}
16305+
1628116306
def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [
1628216307
AllowsTypeRefinement,
1628316308
HasValueSemantics,
@@ -16425,6 +16450,31 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
1642516450
let hasCanonicalizer = 1;
1642616451
}
1642716452

16453+
def Torch_AtenMulLeftTOp : Torch_Op<"aten.mul.left_t", [
16454+
AllowsTypeRefinement,
16455+
HasValueSemantics,
16456+
ReadOnly
16457+
]> {
16458+
let summary = "Generated op for `aten::mul.left_t : (t[], int) -> (t[])`";
16459+
let arguments = (ins
16460+
AnyTorchListType:$l,
16461+
Torch_IntType:$n
16462+
);
16463+
let results = (outs
16464+
AnyTorchListType:$result
16465+
);
16466+
let hasCustomAssemblyFormat = 1;
16467+
let extraClassDefinition = [{
16468+
ParseResult AtenMulLeftTOp::parse(OpAsmParser &parser, OperationState &result) {
16469+
return parseDefaultTorchOp(parser, result, 2, 1);
16470+
}
16471+
void AtenMulLeftTOp::print(OpAsmPrinter &printer) {
16472+
printDefaultTorchOp(printer, *this, 2, 1);
16473+
}
16474+
}];
16475+
let hasCanonicalizer = 1;
16476+
}
16477+
1642816478
def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [
1642916479
AllowsTypeRefinement,
1643016480
ReadOnly

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,22 @@ OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
769769
return nullptr;
770770
}
771771

772+
//===----------------------------------------------------------------------===//
773+
// AtenEqBoolOp
774+
//===----------------------------------------------------------------------===//
775+
776+
OpFoldResult AtenEqBoolOp::fold(FoldAdaptor adaptor) {
777+
if (getOperand(0) == getOperand(1))
778+
return IntegerAttr::get(IntegerType::get(getContext(), 1), true);
779+
780+
auto intAttrA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
781+
auto intAttrB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
782+
if (!intAttrA || !intAttrB)
783+
return nullptr;
784+
return IntegerAttr::get(IntegerType::get(getContext(), 1),
785+
intAttrA.getValue() == intAttrB.getValue());
786+
}
787+
772788
//===----------------------------------------------------------------------===//
773789
// AtenNeBoolOp
774790
//===----------------------------------------------------------------------===//
@@ -777,12 +793,12 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
777793
if (getOperand(0) == getOperand(1))
778794
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
779795

780-
bool a, b;
781-
if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
782-
return nullptr;
783-
if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
796+
auto intAttrA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
797+
auto intAttrB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
798+
if (!intAttrA || !intAttrB)
784799
return nullptr;
785-
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
800+
return IntegerAttr::get(IntegerType::get(getContext(), 1),
801+
intAttrA.getValue() != intAttrB.getValue());
786802
}
787803

788804
//===----------------------------------------------------------------------===//
@@ -1131,6 +1147,35 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
11311147
});
11321148
}
11331149

1150+
//===----------------------------------------------------------------------===//
1151+
// AtenMulLeftTOp
1152+
//===----------------------------------------------------------------------===//
1153+
1154+
void AtenMulLeftTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1155+
MLIRContext *context) {
1156+
// `[1,2] * 3` -> `[1,2,1,2,1,2]`, if it is not mutated.
1157+
patterns.add(+[](AtenMulLeftTOp op, PatternRewriter &rewriter) {
1158+
auto listLiteral = op.getL().getDefiningOp<Torch::PrimListConstructOp>();
1159+
if (!listLiteral || isListPotentiallyMutated(listLiteral))
1160+
return failure();
1161+
1162+
int64_t numReps;
1163+
if (!matchPattern(op.getN(), m_TorchConstantInt(&numReps)))
1164+
return failure();
1165+
1166+
SmallVector<Value> newListElements;
1167+
for (int rep = 0; rep < numReps; ++rep) {
1168+
for (auto operand : listLiteral.getOperands()) {
1169+
newListElements.push_back(operand);
1170+
}
1171+
}
1172+
1173+
rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, op.getL().getType(),
1174+
newListElements);
1175+
return success();
1176+
});
1177+
}
1178+
11341179
//===----------------------------------------------------------------------===//
11351180
// AtenMinOtherOp
11361181
//===----------------------------------------------------------------------===//

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7737,7 +7737,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
77377737
" %3 = torch.prim.If %2 -> (!torch.list<int>) {\n"
77387738
" %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>\n"
77397739
" %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n"
7740-
" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list<int>, !torch.int) -> !torch.list<int> \n"
7740+
" %7 = torch.aten.mul.left_t %5, %6 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
77417741
" %8 = torch.aten.add.t %7, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
77427742
" torch.prim.If.yield %8 : !torch.list<int>\n"
77437743
" } else {\n"
@@ -8948,7 +8948,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
89488948
" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
89498949
" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
89508950
" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
8951-
" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
8951+
" %17 = torch.aten.mul.left_t %15, %16 : !torch.list<bool>, !torch.int -> !torch.list<bool>\n"
89528952
" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
89538953
" torch.prim.Loop %18, %true, init() {\n"
89548954
" ^bb0(%arg8: !torch.int):\n"
@@ -9812,7 +9812,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
98129812
" %76 = torch.aten.append.t %72, %75 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
98139813
" torch.prim.Loop.condition %true, iter()\n"
98149814
" } : (!torch.int, !torch.bool) -> ()\n"
9815-
" %74 = torch.operator \"aten.add_.t\"(%71, %72) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int> \n"
9815+
" %74 = torch.aten.add.t %71, %72 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
98169816
" return %74 : !torch.list<int>\n"
98179817
" }\n"
98189818
" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
@@ -10976,7 +10976,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1097610976
" torch.prim.If.yield %true : !torch.bool\n"
1097710977
" } else {\n"
1097810978
" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional<bool> -> !torch.bool\n"
10979-
" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n"
10979+
" %25 = torch.aten.eq.bool %24, %true : !torch.bool, !torch.bool -> !torch.bool\n"
1098010980
" torch.prim.If.yield %25 : !torch.bool\n"
1098110981
" }\n"
1098210982
" torch.prim.If %17 -> () {\n"
@@ -10995,7 +10995,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1099510995
" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
1099610996
" %23 = torch.prim.If %22 -> (!torch.bool) {\n"
1099710997
" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
10998-
" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n"
10998+
" %25 = torch.aten.eq.bool %24, %false : !torch.bool, !torch.bool -> !torch.bool\n"
1099910999
" torch.prim.If.yield %25 : !torch.bool\n"
1100011000
" } else {\n"
1100111001
" torch.prim.If.yield %false : !torch.bool\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,8 +1757,7 @@ def aten〇col2im〡shape(self: List[int], output_size: List[int], kernel_size:
17571757

17581758
# compute the shape of the output
17591759
num_channels = n_input_plane // (kernel_size[0] * kernel_size[1])
1760-
out: List[int] = [self[0], num_channels] if batch_dim == 0 else [num_channels]
1761-
out += [elem for elem in output_size]
1760+
out: List[int] = ([self[0], num_channels] if batch_dim == 0 else [num_channels]) + [elem for elem in output_size]
17621761

17631762
return out
17641763

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,12 +1134,14 @@ def emit_with_mutating_variants(key, **kwargs):
11341134
emit("aten::gt.float_int : (float, int) -> (bool)")
11351135
emit("aten::pow.int_float : (int, float) -> (float)", has_folder=True)
11361136
emit("aten::__and__.bool : (bool, bool) -> (bool)")
1137+
emit("aten::eq.bool : (bool, bool) -> (bool)", has_folder=True)
11371138
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
11381139
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
11391140
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
11401141
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
11411142
emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True)
11421143
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
1144+
emit("aten::mul.left_t : (t[], int) -> (t[])", has_canonicalizer=True)
11431145
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
11441146
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
11451147
emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True)

test/Dialect/Torch/canonicalize.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,46 @@ func.func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torc
137137
return %0 : !torch.bool
138138
}
139139

140+
// CHECK-LABEL: func.func @torch.aten.eq.bool$same_value() -> !torch.bool {
141+
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
142+
// CHECK: return %[[TRUE]] : !torch.bool
143+
func.func @torch.aten.eq.bool$same_value() -> !torch.bool {
144+
%a = torch.constant.bool false
145+
%b = torch.constant.bool false
146+
%0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
147+
return %0 : !torch.bool
148+
}
149+
150+
// CHECK-LABEL: func.func @torch.aten.eq.bool$different_value() -> !torch.bool {
151+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
152+
// CHECK: return %[[FALSE]] : !torch.bool
153+
func.func @torch.aten.eq.bool$different_value() -> !torch.bool {
154+
%a = torch.constant.bool true
155+
%b = torch.constant.bool false
156+
%0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
157+
return %0 : !torch.bool
158+
}
159+
160+
// CHECK-LABEL: func.func @torch.aten.eq.bool$same_operand(
161+
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
162+
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
163+
// CHECK: return %[[TRUE]] : !torch.bool
164+
func.func @torch.aten.eq.bool$same_operand(%arg0: !torch.bool) -> !torch.bool {
165+
%0 = torch.aten.eq.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool
166+
return %0 : !torch.bool
167+
}
168+
169+
// CHECK-LABEL: func.func @torch.aten.eq.bool$different_operand(
170+
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
171+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
172+
// CHECK: %[[RET:.*]] = torch.aten.eq.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool
173+
// CHECK: return %[[RET]] : !torch.bool
174+
func.func @torch.aten.eq.bool$different_operand(%a: !torch.bool) -> !torch.bool {
175+
%b = torch.constant.bool false
176+
%0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
177+
return %0 : !torch.bool
178+
}
179+
140180
// CHECK-LABEL: func.func @torch.aten.ne.bool() -> !torch.bool {
141181
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
142182
// CHECK: return %[[TRUE]] : !torch.bool
@@ -698,6 +738,20 @@ func.func @torch.aten.len.t$no_fold_list_mutated() -> !torch.int {
698738
return %2 : !torch.int
699739
}
700740

741+
// CHECK-LABEL: func.func @torch.aten.mul.left_t(
742+
// CHECK: %[[C4:.*]] = torch.constant.int 4
743+
// CHECK: %[[C5:.*]] = torch.constant.int 5
744+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]], %[[C4]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
745+
// CHECK: return %[[LIST]] : !torch.list<int>
746+
func.func @torch.aten.mul.left_t() -> !torch.list<int> {
747+
%int4 = torch.constant.int 4
748+
%int5 = torch.constant.int 5
749+
%int2 = torch.constant.int 2
750+
%0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
751+
%1 = torch.aten.mul.left_t %0, %int2 : !torch.list<int>, !torch.int -> !torch.list<int>
752+
return %1 : !torch.list<int>
753+
}
754+
701755
// CHECK-LABEL: func.func @torch.aten.__getitem__.t(
702756
// CHECK: %[[C5:.*]] = torch.constant.int 5
703757
// CHECK: return %[[C5]] : !torch.int

0 commit comments

Comments (0)