
Commit 03ba03e

[Relax][PyTorch] Cleanup tests for ExportedProgram frontend (#17822)
* move gelu, relu, selu, sigmoid, silu tests to test_basic_unary_ops
* remove unused torch_version
* we don't need to manually call `test_*` functions
* remove unused variable
1 parent bf61216 commit 03ba03e

File tree

1 file changed: +7 -179 lines changed

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 7 additions & 179 deletions
@@ -27,9 +27,6 @@
 from tvm.script import relax as R
 from tvm.script import tir as T
 from tvm.relax.frontend.torch import from_exported_program
-from packaging import version
-
-torch_version = torch.__version__
 
 
 def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None):
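
The deleted `from packaging import version` import and `torch_version` assignment were leftovers: nothing in the file reads them anymore. For context, this is the kind of version gate they would typically serve; a hypothetical sketch (the version bound and test name are invented, and no such gate remains in this file):

```python
import pytest
import torch
from packaging import version

torch_version = torch.__version__

# Hypothetical gate: skip a test on PyTorch releases older than some bound.
@pytest.mark.skipif(
    version.parse(torch_version) < version.parse("2.2.0"),
    reason="requires a newer torch.export",
)
def test_needs_recent_torch():
    ...
```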
@@ -56,10 +53,17 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None):
     (torch.erf, R.erf),
     (torch.exp, R.exp),
     (torch.floor, R.floor),
+    (torch.ops.aten.gelu, R.nn.gelu),
     (torch.log, R.log),
     (torch.neg, R.negative),
+    (torch.relu, R.nn.relu),
+    (torch.relu_, R.nn.relu),
     (torch.round, R.round),
     (torch.rsqrt, R.rsqrt),
+    (torch.selu, R.nn.selu),
+    (torch.sigmoid, R.sigmoid),
+    (torch.ops.aten.silu, R.nn.silu),
+    (torch.ops.aten.silu_, R.nn.silu),
     (torch.sin, R.sin),
     (torch.sinh, R.sinh),
     (torch.sign, R.sign),
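
Each `(torch_op, relax_op)` pair added above replaces a hand-written per-op test deleted further down. A minimal sketch of the parametrized pattern these pairs feed, per the commit message's `test_basic_unary_ops` (this sketch only checks that conversion succeeds; the real test also compares against an expected TVMScript module via `verify_model`):

```python
import pytest
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

@pytest.mark.parametrize("torch_op", [torch.relu, torch.sigmoid, torch.erf])
def test_basic_unary_ops_sketch(torch_op):
    class UnaryOp(torch.nn.Module):
        def forward(self, x):
            return torch_op(x)

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    # Export to an ExportedProgram, then translate to a Relax IRModule.
    exported = export(UnaryOp(), example_args)
    mod = from_exported_program(exported)
    assert mod is not None
```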
@@ -314,35 +318,6 @@ def main(
     verify_model(Elu(), example_args, {}, expected_elu)
     verify_model(Elu2(), example_args, {}, expected_elu)
 
-    # gelu
-    class Gelu(Module):
-        def __init__(self):
-            super().__init__()
-            self.gelu = torch.nn.GELU()
-
-        def forward(self, input):
-            return self.gelu(input)
-
-    class Gelu2(Module):
-        def forward(self, input):
-            return torch.nn.functional.gelu(input)
-
-    @tvm.script.ir_module
-    class expected_gelu:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Gelu(), example_args, {}, expected_gelu)
-    verify_model(Gelu2(), example_args, {}, expected_gelu)
-
     # hardsigmoid
     class Hardsigmoid(torch.nn.Module):
         def __init__(self):
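
For readers skimming the diff: `verify_model` (defined near the top of this file) drives the ExportedProgram frontend and is what every surviving test goes through. A simplified sketch of the flow it presumably implements (binding and dynamic-shape handling omitted):

```python
import torch
from torch.export import export
import tvm
from tvm.relax.frontend.torch import from_exported_program

def verify_model_sketch(torch_model, example_args, expected):
    # Trace the module into a torch.export.ExportedProgram.
    exported_program = export(torch_model, args=example_args)
    # Translate it into a Relax IRModule.
    mod = from_exported_program(exported_program)
    # Require structural equality with the hand-written TVMScript module.
    tvm.ir.assert_structural_equal(mod, expected)
```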
@@ -413,15 +388,6 @@ def main(
     verify_model(Hardswish2(), example_args, {}, expected1)
     verify_model(Hardswish3(), example_args, {}, expected1)
 
-    # hardtanh
-    test_hardtanh()
-
-    # leakyrelu
-    test_leakyrelu()
-
-    # softplus
-    test_softplus()
-
     # log2
     class Log2(Module):
         def forward(self, x):
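
The deleted `test_hardtanh()` / `test_leakyrelu()` / `test_softplus()` calls were redundant: pytest collects every module-level `test_*` function by itself, so invoking one test from inside another only runs it twice. TVM test files conventionally rely on the standard entry point instead (assuming this file follows that convention):

```python
if __name__ == "__main__":
    tvm.testing.main()
```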
@@ -487,9 +453,6 @@ def main(
 
     verify_model(Log1p(), example_args, {}, Expected_log1p)
 
-    # log_softmax
-    test_logsoftmax()
-
     # reciprocal
     class Reciprocal(Module):
         def forward(self, input):
@@ -511,140 +474,6 @@ def main(
 
     verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
 
-    # relu
-    class ReLU0(Module):
-        def __init__(self):
-            super().__init__()
-            self.relu = torch.nn.ReLU()
-
-        def forward(self, input):
-            return self.relu(input)
-
-    class ReLU1(Module):
-        def forward(self, input):
-            return torch.nn.functional.relu(input)
-
-    class ReLU2(Module):
-        def forward(self, input):
-            return torch.ops.aten.relu_(input)
-
-    @tvm.script.ir_module
-    class expected_relu:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(ReLU0(), example_args, {}, expected_relu)
-    verify_model(ReLU1(), example_args, {}, expected_relu)
-    verify_model(ReLU2(), example_args, {}, expected_relu)
-
-    # selu
-    class Selu1(Module):
-        def __init__(self):
-            super().__init__()
-            self.selu = torch.nn.SELU()
-
-        def forward(self, input):
-            return self.selu(input)
-
-    class Selu2(Module):
-        def forward(self, input):
-            return torch.nn.functional.selu(input)
-
-    @tvm.script.ir_module
-    class expected_selu:
-        @R.function
-        def main(
-            input: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Selu1(), example_args, {}, expected_selu)
-    verify_model(Selu2(), example_args, {}, expected_selu)
-
-    # sigmoid
-    class Sigmoid(Module):
-        def __init__(self):
-            super().__init__()
-            self.sigmoid = torch.nn.Sigmoid()
-
-        def forward(self, input):
-            return self.sigmoid(input)
-
-    class Sigmoid2(Module):
-        def forward(self, input):
-            return torch.sigmoid(input)
-
-    @tvm.script.ir_module
-    class expected_sigmoid:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Sigmoid(), example_args, {}, expected_sigmoid)
-    verify_model(Sigmoid2(), example_args, {}, expected_sigmoid)
-
-    # silu
-    class SiLU(Module):
-        def __init__(self):
-            super().__init__()
-            self.silu = torch.nn.SiLU()
-
-        def forward(self, input):
-            return self.silu(input)
-
-    class SiLU2(Module):
-        def forward(self, input):
-            return torch.nn.functional.silu(input)
-
-    class SiLU3(Module):
-        def forward(self, input):
-            return torch.ops.aten.silu_(input)
-
-    @tvm.script.ir_module
-    class expected_silu:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(SiLU(), example_args, {}, expected_silu)
-    verify_model(SiLU2(), example_args, {}, expected_silu)
-    verify_model(SiLU3(), example_args, {}, expected_silu)
-
-    # softmax
-    test_softmax()
-
-    # softshrink
-    test_softshrink()
-
-    # tril, triu
-    test_tril_triu()
-
 
 def test_hardtanh():
     class Hardtanh(torch.nn.Module):
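
One detail behind the consolidated pairs: the in-place variants (`torch.relu_`, `torch.ops.aten.silu_`) map to the same functional Relax ops as their out-of-place forms. That is expected, because `torch.export` functionalizes the graph before the frontend sees it; a quick illustrative check (not part of the test file):

```python
import torch
from torch.export import export

class InplaceRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu_(x)

ep = export(InplaceRelu(), (torch.randn(2, 2),))
# The exported ATen graph is functional, so the in-place relu_ should show
# up as an out-of-place relu in the printed graph code.
print(ep.graph_module.code)
```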
@@ -1044,7 +873,6 @@ def test_binary3():
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(10, 10, dtype=torch.float32),
     )
-    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
     # Max
     class Max1(Module):
