2727from tvm .script import relax as R
2828from tvm .script import tir as T
2929from tvm .relax .frontend .torch import from_exported_program
30- from packaging import version
31-
32- torch_version = torch .__version__
3330
3431
3532def verify_model (torch_model , example_args , binding , expected , dynamic_shapes = None ):
@@ -56,10 +53,17 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=No
5653 (torch .erf , R .erf ),
5754 (torch .exp , R .exp ),
5855 (torch .floor , R .floor ),
56+ (torch .ops .aten .gelu , R .nn .gelu ),
5957 (torch .log , R .log ),
6058 (torch .neg , R .negative ),
59+ (torch .relu , R .nn .relu ),
60+ (torch .relu_ , R .nn .relu ),
6161 (torch .round , R .round ),
6262 (torch .rsqrt , R .rsqrt ),
63+ (torch .selu , R .nn .selu ),
64+ (torch .sigmoid , R .sigmoid ),
65+ (torch .ops .aten .silu , R .nn .silu ),
66+ (torch .ops .aten .silu_ , R .nn .silu ),
6367 (torch .sin , R .sin ),
6468 (torch .sinh , R .sinh ),
6569 (torch .sign , R .sign ),
@@ -314,35 +318,6 @@ def main(
314318 verify_model (Elu (), example_args , {}, expected_elu )
315319 verify_model (Elu2 (), example_args , {}, expected_elu )
316320
317- # gelu
318- class Gelu (Module ):
319- def __init__ (self ):
320- super ().__init__ ()
321- self .gelu = torch .nn .GELU ()
322-
323- def forward (self , input ):
324- return self .gelu (input )
325-
326- class Gelu2 (Module ):
327- def forward (self , input ):
328- return torch .nn .functional .gelu (input )
329-
330- @tvm .script .ir_module
331- class expected_gelu :
332- @R .function
333- def main (
334- input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
335- ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
336- # block 0
337- with R .dataflow ():
338- lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .gelu (input_1 )
339- gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv ,)
340- R .output (gv )
341- return gv
342-
343- verify_model (Gelu (), example_args , {}, expected_gelu )
344- verify_model (Gelu2 (), example_args , {}, expected_gelu )
345-
346321 # hardsigmoid
347322 class Hardsigmoid (torch .nn .Module ):
348323 def __init__ (self ):
@@ -413,15 +388,6 @@ def main(
413388 verify_model (Hardswish2 (), example_args , {}, expected1 )
414389 verify_model (Hardswish3 (), example_args , {}, expected1 )
415390
416- # hardtanh
417- test_hardtanh ()
418-
419- # leakyrelu
420- test_leakyrelu ()
421-
422- # softplus
423- test_softplus ()
424-
425391 # log2
426392 class Log2 (Module ):
427393 def forward (self , x ):
@@ -487,9 +453,6 @@ def main(
487453
488454 verify_model (Log1p (), example_args , {}, Expected_log1p )
489455
490- # log_softmax
491- test_logsoftmax ()
492-
493456 # reciprocal
494457 class Reciprocal (Module ):
495458 def forward (self , input ):
@@ -511,140 +474,6 @@ def main(
511474
512475 verify_model (Reciprocal (), example_args , {}, expected_reciprocal )
513476
514- # relu
515- class ReLU0 (Module ):
516- def __init__ (self ):
517- super ().__init__ ()
518- self .relu = torch .nn .ReLU ()
519-
520- def forward (self , input ):
521- return self .relu (input )
522-
523- class ReLU1 (Module ):
524- def forward (self , input ):
525- return torch .nn .functional .relu (input )
526-
527- class ReLU2 (Module ):
528- def forward (self , input ):
529- return torch .ops .aten .relu_ (input )
530-
531- @tvm .script .ir_module
532- class expected_relu :
533- @R .function
534- def main (
535- input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
536- ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
537- # block 0
538- with R .dataflow ():
539- lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .relu (input_1 )
540- gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv ,)
541- R .output (gv )
542- return gv
543-
544- verify_model (ReLU0 (), example_args , {}, expected_relu )
545- verify_model (ReLU1 (), example_args , {}, expected_relu )
546- verify_model (ReLU2 (), example_args , {}, expected_relu )
547-
548- # selu
549- class Selu1 (Module ):
550- def __init__ (self ):
551- super ().__init__ ()
552- self .selu = torch .nn .SELU ()
553-
554- def forward (self , input ):
555- return self .selu (input )
556-
557- class Selu2 (Module ):
558- def forward (self , input ):
559- return torch .nn .functional .selu (input )
560-
561- @tvm .script .ir_module
562- class expected_selu :
563- @R .function
564- def main (
565- input : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
566- ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
567- with R .dataflow ():
568- lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .selu (input )
569- gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv ,)
570- R .output (gv )
571- return gv
572-
573- verify_model (Selu1 (), example_args , {}, expected_selu )
574- verify_model (Selu2 (), example_args , {}, expected_selu )
575-
576- # sigmoid
577- class Sigmoid (Module ):
578- def __init__ (self ):
579- super ().__init__ ()
580- self .sigmoid = torch .nn .Sigmoid ()
581-
582- def forward (self , input ):
583- return self .sigmoid (input )
584-
585- class Sigmoid2 (Module ):
586- def forward (self , input ):
587- return torch .sigmoid (input )
588-
589- @tvm .script .ir_module
590- class expected_sigmoid :
591- @R .function
592- def main (
593- input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
594- ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
595- # block 0
596- with R .dataflow ():
597- lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .sigmoid (input_1 )
598- gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv ,)
599- R .output (gv )
600- return gv
601-
602- verify_model (Sigmoid (), example_args , {}, expected_sigmoid )
603- verify_model (Sigmoid2 (), example_args , {}, expected_sigmoid )
604-
605- # silu
606- class SiLU (Module ):
607- def __init__ (self ):
608- super ().__init__ ()
609- self .silu = torch .nn .SiLU ()
610-
611- def forward (self , input ):
612- return self .silu (input )
613-
614- class SiLU2 (Module ):
615- def forward (self , input ):
616- return torch .nn .functional .silu (input )
617-
618- class SiLU3 (Module ):
619- def forward (self , input ):
620- return torch .ops .aten .silu_ (input )
621-
622- @tvm .script .ir_module
623- class expected_silu :
624- @R .function
625- def main (
626- input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
627- ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
628- # block 0
629- with R .dataflow ():
630- lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .silu (input_1 )
631- gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv ,)
632- R .output (gv )
633- return gv
634-
635- verify_model (SiLU (), example_args , {}, expected_silu )
636- verify_model (SiLU2 (), example_args , {}, expected_silu )
637- verify_model (SiLU3 (), example_args , {}, expected_silu )
638-
639- # softmax
640- test_softmax ()
641-
642- # softshrink
643- test_softshrink ()
644-
645- # tril, triu
646- test_tril_triu ()
647-
648477
649478def test_hardtanh ():
650479 class Hardtanh (torch .nn .Module ):
@@ -1044,7 +873,6 @@ def test_binary3():
1044873 torch .randn (10 , 10 , dtype = torch .float32 ),
1045874 torch .randn (10 , 10 , dtype = torch .float32 ),
1046875 )
1047- example_args2 = (torch .randn (10 , 10 , dtype = torch .float32 ),)
1048876
1049877 # Max
1050878 class Max1 (Module ):
0 commit comments