@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
635
635
module .forward (tu .rand (1 , 2 , 1 , 3 ), tu .rand (2 ), tu .rand (2 ))
636
636
637
637
638
+ # ==============================================================================
639
+ class RMSNormModule (torch .nn .Module ):
640
+ def __init__ (self ):
641
+ super ().__init__ ()
642
+
643
+ @export
644
+ @annotate_args (
645
+ [
646
+ None ,
647
+ ([8 , 9 , 1 , 2 , 4 ], torch .float32 , True ),
648
+ ([1 , 2 , 4 ], torch .float32 , True ),
649
+ ]
650
+ )
651
+ def forward (self , x , weight ):
652
+ list = [1 , 2 , 4 ]
653
+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.5 )
654
+
655
+
656
+ @register_test_case (module_factory = lambda : RMSNormModule ())
657
+ def RMSNormModule_basic (module , tu : TestUtils ):
658
+ module .forward (tu .rand (8 , 9 , 1 , 2 , 4 ), tu .rand (1 , 2 , 4 ))
659
+
660
+
661
+ class RMSNormWithoutEpsModule (torch .nn .Module ):
662
+ def __init__ (self ):
663
+ super ().__init__ ()
664
+
665
+ @export
666
+ @annotate_args (
667
+ [
668
+ None ,
669
+ ([2 , 5 , 2 , 2 , 3 ], torch .float32 , True ),
670
+ ([2 , 2 , 3 ], torch .float32 , True ),
671
+ ]
672
+ )
673
+ def forward (self , x , weight ):
674
+ list = [2 , 2 , 3 ]
675
+ return torch .ops .aten .rms_norm (x , list , weight )
676
+
677
+
678
+ @register_test_case (module_factory = lambda : RMSNormWithoutEpsModule ())
679
+ def RMSNormWithoutEpsModule_basic (module , tu : TestUtils ):
680
+ module .forward (tu .rand (2 , 5 , 2 , 2 , 3 ), tu .rand (2 , 2 , 3 ))
681
+
682
+
683
+ class RMSNormWithoutWeightModule (torch .nn .Module ):
684
+ def __init__ (self ):
685
+ super ().__init__ ()
686
+
687
+ @export
688
+ @annotate_args (
689
+ [
690
+ None ,
691
+ ([1 , 2 , 3 , 4 ], torch .float32 , True ),
692
+ ]
693
+ )
694
+ def forward (self , x ):
695
+ list = [4 ]
696
+ return torch .ops .aten .rms_norm (x , list , eps = 0.5 )
697
+
698
+
699
+ @register_test_case (module_factory = lambda : RMSNormWithoutWeightModule ())
700
+ def RMSNormWithoutWeightModule_basic (module , tu : TestUtils ):
701
+ module .forward (tu .rand (1 , 2 , 3 , 4 ))
702
+
703
+
704
+ class RMSNormAllNormalizeModule (torch .nn .Module ):
705
+ def __init__ (self ):
706
+ super ().__init__ ()
707
+
708
+ @export
709
+ @annotate_args (
710
+ [None , ([5 , 6 , 3 ], torch .float32 , True ), ([5 , 6 , 3 ], torch .float32 , True )]
711
+ )
712
+ def forward (self , x , weight ):
713
+ list = [5 , 6 , 3 ]
714
+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.7 )
715
+
716
+
717
+ @register_test_case (module_factory = lambda : RMSNormAllNormalizeModule ())
718
+ def RMSNormAllNormalizeModule_basic (module , tu : TestUtils ):
719
+ module .forward (tu .rand (5 , 6 , 3 ), tu .rand (5 , 6 , 3 ))
720
+
721
+
722
+ class RMSNormDynamicModule (torch .nn .Module ):
723
+ def __init__ (self ):
724
+ super ().__init__ ()
725
+
726
+ @export
727
+ @annotate_args (
728
+ [
729
+ None ,
730
+ ([- 1 , - 1 , - 1 , - 1 ], torch .float32 , True ),
731
+ ([- 1 , - 1 , - 1 ], torch .float32 , True ),
732
+ ]
733
+ )
734
+ def forward (self , x , weight ):
735
+ list = [2 , 3 , 4 ]
736
+ return torch .ops .aten .rms_norm (x , list , weight , eps = 0.8 )
737
+
738
+
739
+ @register_test_case (module_factory = lambda : RMSNormDynamicModule ())
740
+ def RMSNormDynamicModule_basic (module , tu : TestUtils ):
741
+ module .forward (tu .rand (1 , 2 , 3 , 4 ), tu .rand (2 , 3 , 4 ))
742
+
743
+
638
744
# ==============================================================================
639
745
class RenormModuleFloat32 (torch .nn .Module ):
640
746
def __init__ (self ):
0 commit comments