Skip to content

Commit a778f99

Browse files
Gaurav Shukla
authored and committed
[TORCH][MLIR] Add E2E support for aten.ceil op
This commit adds lowering of `aten.ceil` op as a part of element-wise ops lowering. Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 03b6edc commit a778f99

File tree

5 files changed

+52
-4
lines changed

5 files changed

+52
-4
lines changed

e2e_testing/torchscript/elementwise.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,22 @@ def forward(self, a):
521521
def ElementwiseFloorModule_basic(module, tu: TestUtils):
522522
module.forward(tu.rand(3, 4))
523523

524+
class ElementwiseCeilModule(torch.nn.Module):
525+
def __init__(self):
526+
super().__init__()
527+
@export
528+
@annotate_args([
529+
None,
530+
([-1, -1], torch.float32, True),
531+
])
532+
533+
def forward(self, a):
534+
return torch.ceil(a)
535+
536+
@register_test_case(module_factory=lambda: ElementwiseCeilModule())
537+
def ElementwiseCeilModule_basic(module, tu: TestUtils):
538+
module.forward(tu.rand(3, 4))
539+
524540
class ElementwisePowModule(torch.nn.Module):
525541
def __init__(self):
526542
super().__init__()

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,34 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
298298
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
299299
}
300300

301+
def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
302+
AllowsTypeRefinement,
303+
HasValueSemantics
304+
]> {
305+
let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`";
306+
let arguments = (ins
307+
AnyTorchTensorType:$self
308+
);
309+
let results = (outs
310+
AnyTorchTensorType:$result
311+
);
312+
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
313+
}
314+
315+
def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [
316+
IsTrailingUnderscoreInplaceVariant,
317+
AllowsTypeRefinement
318+
]> {
319+
let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`";
320+
let arguments = (ins
321+
AnyTorchTensorType:$self
322+
);
323+
let results = (outs
324+
AnyTorchTensorType:$result
325+
);
326+
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
327+
}
328+
301329
def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
302330
AllowsTypeRefinement,
303331
HasValueSemantics

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
15121512
return b.create<math::ExpOp>(loc, payloadArgs[0]);
15131513
if (isa<AtenFloorOp>(op))
15141514
return b.create<math::FloorOp>(loc, payloadArgs[0]);
1515+
if (isa<AtenCeilOp>(op))
1516+
return b.create<math::CeilOp>(loc, payloadArgs[0]);
15151517
if (isa<AtenLogOp>(op))
15161518
return b.create<math::LogOp>(loc, payloadArgs[0]);
15171519
if (isa<AtenSqrtOp>(op))
@@ -2067,7 +2069,8 @@ struct ConvertElementwiseOp : ConversionPattern {
20672069
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
20682070
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
20692071
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
2070-
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
2072+
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
2073+
AtenCeilOp>(op))
20712074
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
20722075

20732076
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3635,8 +3638,8 @@ class ConvertTorchToLinalg
36353638
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
36363639
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
36373640
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
3638-
AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
3639-
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
3641+
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
3642+
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
36403643
AtenWhereSelfOp>();
36413644
patterns.add<ConvertElementwiseOp>(typeConverter, context);
36423645
target.addIllegalOp<AtenSqueezeOp>();

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
242242
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
243243
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
244244
AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
245-
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) {
245+
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
246246
return getLatticeElement(op->getResult(0)).join(*operands[0]);
247247
}
248248

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def emit_with_mutating_variants(key, **kwargs):
447447
"aten::cos : (Tensor) -> (Tensor)",
448448
"aten::neg : (Tensor) -> (Tensor)",
449449
"aten::floor : (Tensor) -> (Tensor)",
450+
"aten::ceil : (Tensor) -> (Tensor)",
450451
"aten::bitwise_not : (Tensor) -> (Tensor)",
451452
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
452453
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",

0 commit comments

Comments
 (0)