
Commit dc7a1ff

[Torch] add fold logic for some ops (#3794)
1 parent 6b289f2 commit dc7a1ff

File tree: 4 files changed (+254, -2 lines)

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions
@@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -12641,6 +12643,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -15334,6 +15337,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [
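
Note on the ODS flag: `let hasFolder = 1;` makes TableGen declare a fold hook on the generated C++ op class, roughly of the shape sketched below (the exact declaration is emitted into the generated headers); the additions to TorchOps.cpp further down supply the bodies for the four ops marked here.

OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor); // sketch; a null OpFoldResult means "not folded"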

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 134 additions & 0 deletions
@@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRsubScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
+  auto fpFold = [](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[1] - inputs[0] * inputs[2];
+  };
+
+  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[1] - inputs[0] * inputs[2];
+  };
+
+  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMulTensorOp
 //===----------------------------------------------------------------------===//
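
For context: aten.rsub.Scalar computes other - self * alpha, which is why both lambdas evaluate inputs[1] - inputs[0] * inputs[2]; naryFolderHelper is the existing splat-folding helper already used by the neighboring elementwise folds in this file. Worked example matching the new test case: with self = dense<3>, other = 2, and alpha = 2, the fold produces 2 - 3 * 2 = -4.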
@@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenDivTensorModeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
+  auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
+  if (!resultTy || !resultTy.hasDtype()) {
+    return nullptr;
+  }
+  std::function<double(ArrayRef<double>)> fpFold;
+  std::function<APInt(ArrayRef<APInt>)> intFold;
+
+  auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
+  auto unsign = false;
+  if (isa<mlir::IntegerType>(resultTy.getDtype())) {
+    unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
+  }
+
+  fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 2);
+    if (!roundMode) {
+      return (double)inputs[0] / inputs[1];
+    } else if (roundMode.getValue().str() == "floor") {
+      return std::floor((double)inputs[0] / inputs[1]);
+    } else {
+      return std::trunc((double)inputs[0] / inputs[1]);
+    }
+  };
+
+  intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 2);
+    auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue();
+    auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue();
+    int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
+    int64_t res;
+    if (roundMode.getValue().str() == "floor") {
+      res = std::floor(lhs / rhs);
+    } else {
+      res = std::trunc(lhs / rhs);
+    }
+    return APInt(bits, res);
+  };
+
+  if (!roundMode) {
+    return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
+                            fpFold, std::nullopt);
+  }
+
+  return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
+                          fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDivScalarModeOp
 //===----------------------------------------------------------------------===//
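
The fold follows the op's rounding_mode semantics: no mode means true division, "floor" rounds toward negative infinity, and any other string (i.e. "trunc") rounds toward zero. When the mode is absent, only the floating-point lambda is supplied (intFold is std::nullopt), so in that case only floating-point results fold; the @fold_aten_div_tensor_mode_none test shows si64 inputs with an f32 result folding to 2.66666675. With "trunc", dense<8> / dense<2> folds to 4; with "floor", dense<8.0> / dense<2.1> folds to floor(3.8095...) = 3, matching the new tests below.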
@@ -3597,6 +3667,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRemainderScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
+  auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
+  if (!resultTy || !resultTy.hasDtype()) {
+    return nullptr;
+  }
+
+  auto unsign = false;
+  if (isa<mlir::IntegerType>(resultTy.getDtype())) {
+    unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
+  }
+  auto fpFold = [](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 2);
+    return std::fmod(inputs[0], inputs[1]);
+  };
+
+  auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 2);
+    auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
+    return ret;
+  };
+
+  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddIntOp
 //===----------------------------------------------------------------------===//
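
The scalar remainder folds with std::fmod on the floating-point path and with APInt::urem / APInt::srem on the integer path, choosing unsigned versus signed from the result dtype. Matching the new tests: dense<3> % 2 folds to dense<1> in both the si64 and the f32 case.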
@@ -4229,6 +4327,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIntTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
+  auto value = adaptor.getA();
+  auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
+  if (!dense || !dense.isSplat()) {
+    return nullptr;
+  }
+
+  auto splat = dense.getSplatValue<Attribute>();
+  if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
+    auto type = getType();
+    if (!isa<mlir::IntegerType>(type)) {
+      return nullptr;
+    }
+
+    if (type.isSignlessInteger()) {
+      return getI64IntegerAttr(getContext(), intAttr.getInt());
+    } else if (type.isSignedInteger()) {
+      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+    } else {
+      return getI64IntegerAttr(getContext(), intAttr.getUInt());
+    }
+  }
+
+  if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
+    return getI64IntegerAttr(
+        getContext(),
+        static_cast<long>(floatAttr.getValue().convertToDouble()));
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenFloatTensorOp
 //===----------------------------------------------------------------------===//
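
The intent of this fold is to turn aten.Int.Tensor of a splat constant tensor into the equivalent torch.constant.int: integer splats go through getInt/getSInt/getUInt depending on signedness, float splats are truncated toward zero (dense<3.1> becomes 3), and bool (i1) splats fold to 0 or 1, which is what the new @fold_aten_int_tensor_* tests exercise.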

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 6 additions & 2 deletions
@@ -379,6 +379,7 @@ def emit_with_mutating_variants(key, **kwargs):
     # variants.
     emit_with_mutating_variants(
         "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
+        has_folder=True,
         has_canonicalizer=True,
     )
     emit_with_mutating_variants(
@@ -481,6 +482,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
+        has_folder=True,
         has_canonicalizer=True,
     )
     emit("aten::gelu : (Tensor, str) -> (Tensor)")
@@ -928,7 +930,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
     )
-    emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
+    emit(
+        "aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
+    )
     emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
     emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
@@ -1080,7 +1084,7 @@ def emit_with_mutating_variants(key, **kwargs):
         has_canonicalizer=True,
     )
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
-    emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
+    emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
     emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)

test/Dialect/Torch/torch-nary-canonicalize.mlir

Lines changed: 110 additions & 0 deletions
@@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
   %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
   return %0 : !torch.vtensor<[4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_aten_rsub_scalar_int
+func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_2 = torch.constant.int 2
+  %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_rsub_scalar_float
+func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_2 = torch.constant.float 2.0
+  %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2 : !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_remainder_scalar_int
+func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_2 = torch.constant.int 2
+  %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_remainder_scalar_float
+func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_2 = torch.constant.float 2.0
+  %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_int_tensor_int
+func.func @fold_aten_int_tensor_int() -> !torch.int {
+  // CHECK: %int3 = torch.constant.int 3
+  %cst_3 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
+  %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_int_tensor_bool
+func.func @fold_aten_int_tensor_bool() -> !torch.int {
+  // CHECK: %int1 = torch.constant.int 1
+  %cst_false = torch.vtensor.literal(dense<true> : tensor<i1>) : !torch.vtensor<[],i1>
+  %0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_int_tensor_float
+func.func @fold_aten_int_tensor_float() -> !torch.int {
+  // CHECK: %int3 = torch.constant.int 3
+  %cst_3 = torch.vtensor.literal(dense<3.1> : tensor<f32>) : !torch.vtensor<[],f32>
+  %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_div_tensor_mode_int
+func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %trunc = torch.constant.str "trunc"
+  %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_div_tensor_mode_float
+func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %floor = torch.constant.str "floor"
+  %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_div_tensor_mode_none
+func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %none = torch.constant.none
+  %0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
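
These checks are driven by the RUN line already at the top of torch-nary-canonicalize.mlir (not part of this hunk); the new folds fire during canonicalization, so reproducing locally looks roughly like the following, with the exact flags taken from the file's own RUN line:

torch-mlir-opt --split-input-file --canonicalize test/Dialect/Torch/torch-nary-canonicalize.mlir | FileCheck test/Dialect/Torch/torch-nary-canonicalize.mlir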
