Commit eddc09a

Gaurav Shukla authored and committed
[TORCH][MLIR] Add E2E support for aten.eq and aten.lt ops
- Added E2E support for `aten.eq.Tensor` and `aten.lt.Tensor` ops. Both
  operands are expected to be of the same type, i.e., type promotion is
  not addressed as a part of this commit.
- Added E2E support for `aten.eq.Scalar` and `aten.lt.Scalar` ops.
  Tensor-to-Scalar operand type promotion has not been handled in this
  commit.

Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 0cd95b5 commit eddc09a
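For orientation, here is a minimal PyTorch-level sketch of what the four ops compute; the values mirror the new test cases below, and the snippet itself is illustrative, not part of the commit:

import torch

x = torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
y = torch.tensor([1.0, 2.4, 6.0])  # same dtype as x; no promotion needed

# Tensor-Tensor variants: elementwise compare, result is a torch.bool tensor.
print(torch.eq(x, y))  # aten.eq.Tensor
print(torch.lt(x, y))  # aten.lt.Tensor

# Tensor-Scalar variants: every element is compared against a Python number.
print(torch.eq(x, 6.0))  # aten.eq.Scalar
print(torch.lt(x, 0.6))  # aten.lt.Scalar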

File tree

5 files changed: +368 −9 lines

e2e_testing/torchscript/elementwise.py

Lines changed: 188 additions & 0 deletions
@@ -419,6 +419,194 @@ def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseLtFloatScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.lt(x, 0.6)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
+def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
+
+class ElementwiseLtIntScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.lt(x, 0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
+def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3, 4)))
+
+
+class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.lt(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
+def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
+
+
+class ElementwiseLtFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.lt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
+def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(5))
+
+
+class ElementwiseLtIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.lt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
+def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))
+
+# ==============================================================================
+
+class ElementwiseEqFloatScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.eq(x, 6.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
+def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
+                   .to(torch.float32))
+
+
+class ElementwiseEqIntScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.eq(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
+def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(2, 4, (5, 8)))
+
+
+class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.eq(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
+def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
+
+
+class ElementwiseEqFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.eq(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
+def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
+                   .to(torch.float32),
+                   torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
+
+
+class ElementwiseEqIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.eq(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
+def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5,)))
+
+# ==============================================================================
 
 class ElementwiseClampModule(torch.nn.Module):
     def __init__(self):
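The tensor-tensor tests above deliberately pair a rank-2 `[-1, -1]` input with a rank-1 `[-1]` input, so they also exercise broadcasting along the leading dimension. A standalone sketch of that behavior in plain PyTorch, outside the e2e harness:

import torch

# Mirrors ElementwiseLtFloatTensorModule_basic: y is broadcast across
# the rows of x before the elementwise comparison.
x = torch.rand(3, 5)
y = torch.rand(5)
result = torch.lt(x, y)
assert result.shape == (3, 5)
assert result.dtype == torch.bool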

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 60 additions & 0 deletions
@@ -570,6 +570,36 @@ def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenLtTensorOp : Torch_Op<"aten.lt.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
+def Torch_AtenLt_TensorOp : Torch_Op<"aten.lt_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::lt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics

@@ -844,6 +874,36 @@ def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
+def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [
     AllowsTypeRefinement,
     HasValueSemantics