Commit 0f95851

[Torch] Add support for aten.round.decimals op (#4166)
* Added decomposition to aten.round
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
1 parent 1f437a9 commit 0f95851

8 files changed: +150 −0 lines

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions
```diff
@@ -4562,6 +4562,54 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
   }];
 }
 
+def Torch_AtenRoundDecimalsOp : Torch_Op<"aten.round.decimals", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::round.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRoundDecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRoundDecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenRound_DecimalsOp : Torch_Op<"aten.round_.decimals", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::round_.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRound_DecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRound_DecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
     AllowsTypeRefinement,
     HasValueSemantics,
```
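
For reference, this is the same `decimals` overload that PyTorch exposes on `torch.round`. A minimal eager-mode sketch of the registered op's semantics (assuming a PyTorch build with the `decimals` keyword, 1.11+):

```python
import torch

# aten::round.decimals : (Tensor, int) -> (Tensor)
# rounds each element to `decimals` decimal places.
x = torch.tensor([1.2345, -2.6789, 0.5551])
y = torch.ops.aten.round(x, decimals=2)  # ≈ tensor([ 1.2300, -2.6800,  0.5600])
```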

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 13 additions & 0 deletions
```diff
@@ -1992,6 +1992,19 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRoundDecimalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getType());
+  if (resultType && resultType.hasDtype() &&
+      isa<mlir::IntegerType>(resultType.getDtype())) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRoundOp
 //===----------------------------------------------------------------------===//
```
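
The folder relies on rounding being an identity on integral tensors, so an `aten.round.decimals` whose result type has an integer dtype folds directly to its input. A quick eager-mode check of that premise (a sketch; integer inputs are the same case the existing `AtenRoundIntModule` test exercises):

```python
import torch

# Rounding cannot change an integral value, which is why the folder
# may return the op's own input for integer result dtypes.
x = torch.randint(-10, 10, (4,))
assert torch.equal(torch.round(x), x)
```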

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -6754,6 +6754,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.round.decimals\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
@@ -13007,6 +13011,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.round.decimals\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
```

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 49 additions & 0 deletions
```diff
@@ -11976,6 +11976,54 @@ class DecomposeAten_AssertScalarOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRoundDecimalsOp
+    : public OpRewritePattern<AtenRoundDecimalsOp> {
+public:
+  using OpRewritePattern<AtenRoundDecimalsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRoundDecimalsOp op,
+                                PatternRewriter &rewriter) const override {
+    // If the decimals value is non-zero, AtenRoundDecimalsOp is decomposed as
+    //   scale = 10 ** decimals; return round(x * scale) / scale
+    // otherwise:
+    //   return round(x)
+
+    auto loc = op.getLoc();
+    auto input = op.getSelf();
+    auto inputType = cast<BaseTensorType>(input.getType());
+
+    if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: non-floating point dtype");
+    }
+
+    int64_t decimals;
+    if (!matchPattern(op.getDecimals(), m_TorchConstantInt(&decimals))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-constant decimal point is not supported.");
+    }
+
+    Value newOp = op->getOperand(0);
+    Value scale;
+    if (decimals) {
+      auto scaleVal = pow(10, decimals);
+      scale = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(scaleVal));
+      newOp = rewriter.create<AtenMulScalarOp>(loc, op.getType(), input, scale);
+    }
+
+    newOp = rewriter.create<AtenRoundOp>(loc, op.getType(), newOp);
+
+    if (decimals) {
+      newOp = rewriter.create<AtenDivScalarOp>(loc, op.getType(), newOp, scale);
+    }
+
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12291,6 +12339,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);
```
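
The rewrite's arithmetic is easy to sanity-check against eager PyTorch. A minimal Python mirror of the scale–round–rescale decomposition (the helper name is illustrative, not part of the commit):

```python
import torch

def round_decimals_decomposed(x: torch.Tensor, decimals: int) -> torch.Tensor:
    # Mirrors DecomposeAtenRoundDecimalsOp: mul by 10**decimals, round, div back.
    if decimals == 0:
        return torch.round(x)               # AtenRoundOp only
    scale = 10.0 ** decimals                # ConstantFloatOp
    return torch.round(x * scale) / scale   # AtenMulScalarOp -> AtenRoundOp -> AtenDivScalarOp

x = torch.randn(5, 5)
assert torch.allclose(round_decimals_decomposed(x, 2),
                      torch.round(x, decimals=2))
```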

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -384,6 +384,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenHstackOp>();
   target.addIllegalOp<AtenColumnStackOp>();
   target.addIllegalOp<AtenRollOp>();
+  target.addIllegalOp<AtenRoundDecimalsOp>();
   target.addIllegalOp<AtenRepeatOp>();
   target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
   target.addIllegalOp<AtenExpandOp>();
```

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -352,6 +352,9 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇round〇decimals〡shape(self: List[int], decimals: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
     if dim < 0:
         dim += len(self)
@@ -3674,6 +3677,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, decimals=0))
+def aten〇round〇decimals〡dtype(self_rank_dtype: Tuple[int, int], decimals: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
 def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
     self_rank, self_dtype = self_rank_dtype
```
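
Both new functions are pure passthroughs: `round.decimals` preserves the input's shape and dtype. A small eager check of the contract they encode (a sketch):

```python
import torch

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.round(x, decimals=1)
assert y.shape == x.shape  # shape fn: upstream unary, same shape
assert y.dtype == x.dtype  # dtype fn: returns self_dtype unchanged
```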

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -451,6 +451,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants(
+        "aten::round.decimals : (Tensor, int) -> (Tensor)", has_folder=True
+    )
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::fix : (Tensor) -> (Tensor)")
     emit("aten::special_expm1 : (Tensor) -> (Tensor)")
```
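
`emit_with_mutating_variants` registers both the value-semantics op and its trailing-underscore in-place twin, which is why `GeneratedTorchOps.td` above gains both `aten.round.decimals` and `aten.round_.decimals`. A brief sketch of the two PyTorch spellings they correspond to:

```python
import torch

x = torch.tensor([1.2345, -2.6789])
y = x.round(decimals=2)  # value semantics -> aten.round.decimals
x.round_(decimals=2)     # in-place       -> aten.round_.decimals
assert torch.equal(x, y)
```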

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -6570,6 +6570,26 @@ def AtenRoundIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(5, 5, low=-10))
 
 
+class AtenRoundFloatDecimalsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.round(x, decimals=2)
+
+
+@register_test_case(module_factory=lambda: AtenRoundFloatDecimalsModule())
+def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
+
+
 # ==============================================================================
```
