Skip to content

Commit 8d8d2c2

Browse files
[MLIR][TORCH] Add E2E support for aten.div.Scalar
This commit adds lowering of `aten.div.Scalar`. Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent 56c6e36 commit 8d8d2c2

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

e2e_testing/torchscript/elementwise.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,21 @@ def forward(self, a):
479479
@register_test_case(module_factory=lambda: ElementwiseRsqrtModule())
480480
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
481481
module.forward(tu.rand(3, 4))
482+
483+
# ==============================================================================
484+
class ElementwiseDivScalarModule(torch.nn.Module):
485+
def __init__(self):
486+
super().__init__()
487+
488+
@export
489+
@annotate_args([
490+
None,
491+
([-1, -1], torch.float32, True),
492+
])
493+
def forward(self, x):
494+
return torch.div(x, 10.0)
495+
496+
497+
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
498+
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
499+
module.forward(tu.rand(3, 4))

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
15941594
Value result = convertScalarToDtype(b, loc, input, dtype);
15951595
return result;
15961596
}
1597+
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
1598+
Type dtype = converter->convertType(divScalar.getType())
1599+
.cast<RankedTensorType>()
1600+
.getElementType();
1601+
if (!dtype.isa<mlir::FloatType>()) {
1602+
divScalar.emitError("unimplemented: non-floating point dtype");
1603+
return nullptr;
1604+
}
1605+
Value self = payloadArgs[0];
1606+
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
1607+
return b.create<arith::DivFOp>(loc, self, other);
1608+
}
15971609

15981610
op->emitError("unimplemented lowering in "
15991611
"createLinalgPayloadCalculationForElementwiseOp");
@@ -1805,7 +1817,8 @@ struct ConvertElementwiseOp : ConversionPattern {
18051817
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
18061818
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
18071819
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
1808-
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>(op))
1820+
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp>(
1821+
op))
18091822
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
18101823

18111824
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

0 commit comments

Comments
 (0)