Skip to content

Commit ec66c32

Browse files
committed
[AutoBump] Merge with dc7a1ff (Oct 16)
2 parents 8609c42 + dc7a1ff commit ec66c32

File tree

5 files changed

+332
-4
lines changed

5 files changed

+332
-4
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
34253425
printDefaultTorchOp(printer, *this, 3, 1);
34263426
}
34273427
}];
3428+
let hasFolder = 1;
34283429
let hasCanonicalizer = 1;
34293430
}
34303431

@@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
49024903
printDefaultTorchOp(printer, *this, 3, 1);
49034904
}
49044905
}];
4906+
let hasFolder = 1;
49054907
let hasCanonicalizer = 1;
49064908
}
49074909

@@ -12716,6 +12718,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
1271612718
printDefaultTorchOp(printer, *this, 1, 1);
1271712719
}
1271812720
}];
12721+
let hasFolder = 1;
1271912722
let hasCanonicalizer = 1;
1272012723
}
1272112724

@@ -15409,6 +15412,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
1540915412
printDefaultTorchOp(printer, *this, 2, 1);
1541015413
}
1541115414
}];
15415+
let hasFolder = 1;
1541215416
}
1541315417

1541415418
def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15351535
});
15361536
}
15371537

1538+
//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//

// Constant-folds aten.rsub.Scalar: result = other - self * alpha.
// Operand order in `inputs` is (self, other, alpha), so both the float and
// integer paths compute inputs[1] - inputs[0] * inputs[2].
OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[1] - inputs[0] * inputs[2];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[1] - inputs[0] * inputs[2];
  };

  // naryFolderHelper handles splat extraction and result materialization.
  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}
15381556
//===----------------------------------------------------------------------===//
15391557
// AtenMulTensorOp
15401558
//===----------------------------------------------------------------------===//
@@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
19791997
});
19801998
}
19811999

2000+
// ===----------------------------------------------------------------------===//
2001+
// AtenDivTensorModeOp
2002+
// ===----------------------------------------------------------------------===//
2003+
2004+
OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
2005+
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
2006+
if (!resultTy || !resultTy.hasDtype()) {
2007+
return nullptr;
2008+
}
2009+
std::function<double(ArrayRef<double>)> fpFold;
2010+
std::function<APInt(ArrayRef<APInt>)> intFold;
2011+
2012+
auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
2013+
auto unsign = false;
2014+
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
2015+
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
2016+
}
2017+
2018+
fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
2019+
assert(inputs.size() == 2);
2020+
if (!roundMode) {
2021+
return (double)inputs[0] / inputs[1];
2022+
} else if (roundMode.getValue().str() == "floor") {
2023+
return std::floor((double)inputs[0] / inputs[1]);
2024+
} else {
2025+
return std::trunc((double)inputs[0] / inputs[1]);
2026+
}
2027+
};
2028+
2029+
intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
2030+
assert(inputs.size() == 2);
2031+
auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue();
2032+
auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue();
2033+
int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
2034+
int64_t res;
2035+
if (roundMode.getValue().str() == "floor") {
2036+
res = std::floor(lhs / rhs);
2037+
} else {
2038+
res = std::trunc(lhs / rhs);
2039+
}
2040+
return APInt(bits, res);
2041+
};
2042+
2043+
if (!roundMode) {
2044+
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
2045+
fpFold, std::nullopt);
2046+
}
2047+
2048+
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
2049+
fpFold, intFold);
2050+
}
2051+
19822052
//===----------------------------------------------------------------------===//
19832053
// AtenDivScalarModeOp
19842054
//===----------------------------------------------------------------------===//
@@ -3612,6 +3682,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
36123682
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
36133683
}
36143684

3685+
// ===----------------------------------------------------------------------===//
3686+
// AtenRemainderScalarOp
3687+
// ===----------------------------------------------------------------------===//
3688+
3689+
OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
3690+
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
3691+
if (!resultTy || !resultTy.hasDtype()) {
3692+
return nullptr;
3693+
}
3694+
3695+
auto unsign = false;
3696+
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
3697+
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
3698+
}
3699+
auto fpFold = [](llvm::ArrayRef<double> inputs) {
3700+
assert(inputs.size() == 2);
3701+
return std::fmod(inputs[0], inputs[1]);
3702+
};
3703+
3704+
auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
3705+
assert(inputs.size() == 2);
3706+
auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
3707+
return ret;
3708+
};
3709+
3710+
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
3711+
}
3712+
36153713
//===----------------------------------------------------------------------===//
36163714
// AtenAddIntOp
36173715
//===----------------------------------------------------------------------===//
@@ -4313,6 +4411,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
43134411
});
43144412
}
43154413

4414+
//===----------------------------------------------------------------------===//
4415+
// AtenIntTensorOp
4416+
//===----------------------------------------------------------------------===//
4417+
4418+
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
4419+
auto value = adaptor.getA();
4420+
auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
4421+
if (!dense || !dense.isSplat()) {
4422+
return nullptr;
4423+
}
4424+
4425+
auto splat = dense.getSplatValue<Attribute>();
4426+
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
4427+
auto type = getType();
4428+
if (!isa<mlir::IntegerType>(type)) {
4429+
return nullptr;
4430+
}
4431+
4432+
if (type.isSignlessInteger()) {
4433+
return getI64IntegerAttr(getContext(), intAttr.getInt());
4434+
} else if (type.isSignedInteger()) {
4435+
return getI64IntegerAttr(getContext(), intAttr.getSInt());
4436+
} else {
4437+
return getI64IntegerAttr(getContext(), intAttr.getUInt());
4438+
}
4439+
}
4440+
4441+
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
4442+
return getI64IntegerAttr(
4443+
getContext(),
4444+
static_cast<long>(floatAttr.getValue().convertToDouble()));
4445+
}
4446+
4447+
return nullptr;
4448+
}
4449+
43164450
//===----------------------------------------------------------------------===//
43174451
// AtenFloatTensorOp
43184452
//===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def emit_with_mutating_variants(key, **kwargs):
379379
# variants.
380380
emit_with_mutating_variants(
381381
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
382+
has_folder=True,
382383
has_canonicalizer=True,
383384
)
384385
emit_with_mutating_variants(
@@ -481,6 +482,7 @@ def emit_with_mutating_variants(key, **kwargs):
481482
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
482483
emit(
483484
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
485+
has_folder=True,
484486
has_canonicalizer=True,
485487
)
486488
emit("aten::gelu : (Tensor, str) -> (Tensor)")
@@ -938,7 +940,9 @@ def emit_with_mutating_variants(key, **kwargs):
938940
emit(
939941
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
940942
)
941-
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
943+
emit(
944+
"aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
945+
)
942946
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
943947
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
944948
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
@@ -1090,7 +1094,7 @@ def emit_with_mutating_variants(key, **kwargs):
10901094
has_canonicalizer=True,
10911095
)
10921096
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
1093-
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
1097+
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
10941098
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
10951099
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
10961100
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)

projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
import torch.utils._pytree as pytree
99
from torch.export.graph_signature import OutputSpec, OutputKind
1010
from torch.export import ExportedProgram
11+
from torch._dynamo.backends.common import aot_autograd
1112

1213
from torch_mlir import fx
1314
from torch_mlir_e2e_test.configs.utils import (
1415
recursively_convert_to_numpy,
1516
recursively_convert_from_numpy,
1617
)
1718
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
19+
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
1820

1921

2022
def refine_result_type(_result):
@@ -31,17 +33,91 @@ def refine_result_type(_result):
3133
class FxImporterTestConfig(TestConfig):
3234
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
3335

def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False):
    """Configure the fx-importer test harness.

    Args:
        backend: backend object used to compile and load imported modules.
        output_type: MLIR output type requested from the fx importer.
        torch_compile: when True, tests run through ``torch.compile`` with a
            custom backend instead of the ``torch.export`` path.
    """
    super().__init__()
    self._backend = backend
    self._output_type = output_type
    self._torch_compile = torch_compile

3942
def compile(self, program: torch.nn.Module, verbose: bool = False) -> torch.nn.Module:
    """No-op compile step: import/compilation is deferred to ``run``."""
    return program

def run(self, artifact: torch.nn.Module, trace: Trace):
    """Execute the trace, dispatching on the configured execution mode."""
    if self._torch_compile:
        # torch.compile path: lower FX graphs captured by dynamo.
        return self._stateless_run(artifact, trace)
    # Default path: torch.export the whole module per trace item.
    return self._export_run(artifact, trace)
54+
def _stateless_run(self, artifact: torch.nn.Module, trace: Trace):
    """Run `artifact` under torch.compile, lowering each dynamo-captured FX
    graph through the stateless fx importer and the configured backend.

    Returns a new Trace whose items carry the backend's outputs for the
    original inputs.
    """
    # Scan the torch-mlir argument annotations for the first dimension
    # annotated as dynamic (-1). Only that single (argument, dim) pair is
    # marked dynamic below — annotations after the first hit are ignored.
    dynamic_argument_pos = None
    dynamic_dim_pos = None
    annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
    for i, annotation in enumerate(annotations):
        if i == 0:  # Skip the "self" annotation.
            continue
        # annotation[2] is the value-semantics flag; non-value-semantic
        # inputs are not supported on this path.
        if not annotation[2]:
            raise ValueError(
                "Can only compile inputs annotated as having value semantics."
            )
        for dim_i, dim in enumerate(annotation[0]):
            if dim == -1:
                # Positions are shifted by one because "self" was skipped.
                dynamic_argument_pos = i - 1
                dynamic_dim_pos = dim_i
                break
        if dynamic_argument_pos is not None:
            break
    result: Trace = []
    for item in trace:

        # Backend handed to aot_autograd: imports the captured FX graph and
        # returns a callable that runs it on the compiled backend module.
        def _base_backend(gm: torch.fx.GraphModule, example_inputs):
            # Drop unused SymInt placeholders before import — presumably
            # artifacts of mark_dynamic that the importer cannot consume
            # (TODO confirm).
            for node in gm.graph.nodes:
                if node.op == "placeholder":
                    if (
                        isinstance(node.meta["val"], torch.SymInt)
                        and not node.users
                    ):
                        gm.graph.erase_node(node)
            module = fx.stateless_fx_import(
                gm,
                output_type=self._output_type,
                model_name=artifact.__class__.__name__,
            )
            module = self._backend.compile(module)
            backend_module = self._backend.load(module)

            def invoke_func(*torch_inputs):
                # Keep only tensor arguments; dynamo may also pass SymInts.
                torch_inputs = [
                    x
                    for x in filter(
                        lambda i: isinstance(i, torch.Tensor), torch_inputs
                    )
                ]
                with torch.no_grad():
                    # Round-trip through numpy: the backend module's entry
                    # point (named after the artifact class) takes/returns
                    # numpy arrays.
                    numpy_inputs = recursively_convert_to_numpy(torch_inputs)
                    return recursively_convert_from_numpy(
                        getattr(backend_module, artifact.__class__.__name__)(
                            *numpy_inputs
                        )
                    )

            return invoke_func

        fw_compiler = aot_autograd(fw_compiler=_base_backend)
        if dynamic_argument_pos is not None:
            # Tell dynamo which input dimension may vary so it traces a
            # symbolic (SymInt) shape instead of specializing.
            torch._dynamo.mark_dynamic(
                item.inputs[dynamic_argument_pos], dynamic_dim_pos
            )
        module = torch.compile(artifact, backend=fw_compiler)
        outputs = module(*item.inputs)
        result.append(
            TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs)
        )
    return result
119+
120+
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
45121
result: Trace = []
46122
for item in trace:
47123
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))

0 commit comments

Comments
 (0)