Commit 6dabf18

Author: Prashant Kumar
Add support for int types in gtScalar op.
Support for integer types in the gtScalar op has been added. The code shares the same logic as the gtTensor op and could be merged with it, which is noted as a TODO.
1 parent 8d4879f · commit 6dabf18
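For context on the integer predicates used in the lowering below: `arith.cmpi` must be told whether the operand bits are ordered as signed (`sgt`) or unsigned (`ugt`), because the same bit pattern compares differently under the two interpretations. A minimal standalone C++ illustration of that difference (not part of the patch):

#include <cstdint>
#include <iostream>

int main() {
  int32_t a = -1; // bit pattern 0xFFFFFFFF
  int32_t b = 10;
  // Signed ordering (what arith::CmpIPredicate::sgt implements): -1 > 10 is false.
  std::cout << (a > b) << "\n"; // prints 0
  // Unsigned ordering (what arith::CmpIPredicate::ugt implements):
  // 0xFFFFFFFF reads as 4294967295, which is greater than 10.
  std::cout << (static_cast<uint32_t>(a) > static_cast<uint32_t>(b)) << "\n"; // prints 1
  return 0;
}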

2 files changed (+67, -10 lines)

2 files changed

+67
-10
lines changed

e2e_testing/torchscript/elementwise.py

Lines changed: 41 additions & 3 deletions
@@ -326,7 +326,7 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class ElementwiseGtScalarModule(torch.nn.Module):
+class ElementwiseGtFloatScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -339,10 +339,47 @@ def forward(self, x):
         return torch.gt(x, 0.6)
 
 
-@register_test_case(module_factory=lambda: ElementwiseGtScalarModule())
-def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
+def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+
+class ElementwiseGtIntScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 10)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
+def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3,4)))
+
+
+class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 7)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
+def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32))
+
+
 class ElementwiseGtFloatTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -361,6 +398,7 @@ def forward(self, x, y):
 def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5), tu.rand(5))
 
+
 class ElementwiseGtIntTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 26 additions & 7 deletions
@@ -1735,14 +1735,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
 
   if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
-    Type dtype = gtScalar.self().getType().cast<ValueTensorType>().getDtype();
-    if (!dtype.isa<mlir::FloatType>()) {
-      gtScalar.emitError("unimplemented: non-floating point operand dtype");
-      return nullptr;
+    Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();
+
+    // TODO: `gtTensor` and `gtScalar` share similar code and can be called from
+    // one static function.
+    Value otherPromoted =
+        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
+
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                     payloadArgs[0], otherPromoted);
+    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
+      if (!operands[1].getType().isa<mlir::IntegerType>()) {
+        // TODO: Promote tensor args from integer to float.
+        gtScalar.emitError(
+            "unimplemented: type promotion from tensor to scalar.");
+        return nullptr;
+      }
+
+      if (intType.isUnsigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                       payloadArgs[0], otherPromoted);
+      if (intType.isSigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+                                       payloadArgs[0], otherPromoted);
     }
-    Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                   payloadArgs[0], otherPromoted);
+    gtScalar.emitError("unimplemented: dtype isn't supported.");
+    return nullptr;
   }
 
   if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
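The TODO in this hunk says `gtTensor` and `gtScalar` could share one static function. A hypothetical sketch of that refactor, assembled only from builder calls that already appear in this patch; the name `createGreaterThanPayload` and its exact signature are illustrative assumptions, not part of the commit:

// Hypothetical shared helper: both AtenGtTensorOp and AtenGtScalarOp could
// call this once their right-hand operand has been promoted to lhs's type.
static Value createGreaterThanPayload(OpBuilder &b, Location loc, Type dtype,
                                      Operation *op, Value lhs, Value rhs) {
  if (dtype.isa<mlir::FloatType>())
    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, lhs, rhs);
  if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
    if (intType.isUnsigned())
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, lhs, rhs);
    if (intType.isSigned())
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
  }
  op->emitError("unimplemented: dtype isn't supported.");
  return nullptr;
}

Under this sketch, the gtScalar path above would reduce to promoting operands[1] with convertScalarToDtype (after its scalar-vs-tensor type check) and forwarding to the helper, and the gtTensor lowering could do the same with its promoted tensor operand.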
