@@ -834,6 +834,52 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
834
834
# ==============================================================================
835
835
836
836
837
+ class ElementwiseReluBFloat16Module (torch .nn .Module ):
838
+ def __init__ (self ):
839
+ super ().__init__ ()
840
+
841
+ @export
842
+ @annotate_args (
843
+ [
844
+ None ,
845
+ ([- 1 , - 1 ], torch .bfloat16 , True ),
846
+ ]
847
+ )
848
+ def forward (self , x ):
849
+ return torch .relu (x )
850
+
851
+
852
+ @register_test_case (module_factory = lambda : ElementwiseReluBFloat16Module ())
853
+ def ElementwiseReluModule_bfloat16 (module , tu : TestUtils ):
854
+ module .forward (tu .rand (4 , 2 , low = - 1 ).to (torch .bfloat16 ))
855
+
856
+
857
+ # ==============================================================================
858
+
859
+
860
+ class ElementwiseReluFloat16Module (torch .nn .Module ):
861
+ def __init__ (self ):
862
+ super ().__init__ ()
863
+
864
+ @export
865
+ @annotate_args (
866
+ [
867
+ None ,
868
+ ([- 1 , - 1 ], torch .float16 , True ),
869
+ ]
870
+ )
871
+ def forward (self , x ):
872
+ return torch .relu (x )
873
+
874
+
875
+ @register_test_case (module_factory = lambda : ElementwiseReluFloat16Module ())
876
+ def ElementwiseReluModule_float16 (module , tu : TestUtils ):
877
+ module .forward (tu .rand (4 , 2 , low = - 1 ).to (torch .float16 ))
878
+
879
+
880
+ # ==============================================================================
881
+
882
+
837
883
class QuantizedReluInt8 (torch .nn .Module ):
838
884
def __init__ (self ):
839
885
super ().__init__ ()
@@ -1769,6 +1815,62 @@ def ElementwiseClampModule_basic(module, tu: TestUtils):
1769
1815
# ==============================================================================
1770
1816
1771
1817
1818
+ class ElementwiseClampBFloat16Module (torch .nn .Module ):
1819
+ def __init__ (self ):
1820
+ super ().__init__ ()
1821
+
1822
+ @export
1823
+ @annotate_args (
1824
+ [
1825
+ None ,
1826
+ ([- 1 , - 1 ], torch .bfloat16 , True ),
1827
+ ]
1828
+ )
1829
+ def forward (self , x ):
1830
+ float_min = torch .clamp (x , min = - 2.0 )
1831
+ int_min = torch .clamp (x , min = - 3 )
1832
+ float_max = torch .clamp (x , max = 2.0 )
1833
+ int_max = torch .clamp (x , max = 3 )
1834
+ both = torch .clamp (x , min = - 5 , max = 5 )
1835
+ return float_min , int_min , float_max , int_max , both
1836
+
1837
+
1838
+ @register_test_case (module_factory = lambda : ElementwiseClampBFloat16Module ())
1839
+ def ElementwiseClampModule_bfloat16 (module , tu : TestUtils ):
1840
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .bfloat16 ))
1841
+
1842
+
1843
+ # ==============================================================================
1844
+
1845
+
1846
+ class ElementwiseClampFloat16Module (torch .nn .Module ):
1847
+ def __init__ (self ):
1848
+ super ().__init__ ()
1849
+
1850
+ @export
1851
+ @annotate_args (
1852
+ [
1853
+ None ,
1854
+ ([- 1 , - 1 ], torch .float16 , True ),
1855
+ ]
1856
+ )
1857
+ def forward (self , x ):
1858
+ float_min = torch .clamp (x , min = - 2.0 )
1859
+ int_min = torch .clamp (x , min = - 3 )
1860
+ float_max = torch .clamp (x , max = 2.0 )
1861
+ int_max = torch .clamp (x , max = 3 )
1862
+ both = torch .clamp (x , min = - 5 , max = 5 )
1863
+ return float_min , int_min , float_max , int_max , both
1864
+
1865
+
1866
+ @register_test_case (module_factory = lambda : ElementwiseClampFloat16Module ())
1867
+ def ElementwiseClampModule_float16 (module , tu : TestUtils ):
1868
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .float16 ))
1869
+
1870
+
1871
+ # ==============================================================================
1872
+
1873
+
1772
1874
class ElementwiseClampMinModule (torch .nn .Module ):
1773
1875
def __init__ (self ):
1774
1876
super ().__init__ ()
@@ -1795,6 +1897,58 @@ def ElementwiseClampMinModule_basic(module, tu: TestUtils):
1795
1897
# ==============================================================================
1796
1898
1797
1899
1900
+ class ElementwiseClampMinBFloat16Module (torch .nn .Module ):
1901
+ def __init__ (self ):
1902
+ super ().__init__ ()
1903
+
1904
+ @export
1905
+ @annotate_args (
1906
+ [
1907
+ None ,
1908
+ ([- 1 , - 1 ], torch .bfloat16 , True ),
1909
+ ]
1910
+ )
1911
+ def forward (self , x ):
1912
+ float_min = torch .ops .aten .clamp_min (x , min = - 2.0 )
1913
+ int_min = torch .ops .aten .clamp_min (x , min = 2 )
1914
+ min = torch .ops .aten .clamp_min (x , min = 11.0 )
1915
+ return float_min , int_min , min
1916
+
1917
+
1918
+ @register_test_case (module_factory = lambda : ElementwiseClampMinBFloat16Module ())
1919
+ def ElementwiseClampMinModule_bfloat16 (module , tu : TestUtils ):
1920
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .bfloat16 ))
1921
+
1922
+
1923
+ # ==============================================================================
1924
+
1925
+
1926
+ class ElementwiseClampMinFloat16Module (torch .nn .Module ):
1927
+ def __init__ (self ):
1928
+ super ().__init__ ()
1929
+
1930
+ @export
1931
+ @annotate_args (
1932
+ [
1933
+ None ,
1934
+ ([- 1 , - 1 ], torch .float16 , True ),
1935
+ ]
1936
+ )
1937
+ def forward (self , x ):
1938
+ float_min = torch .ops .aten .clamp_min (x , min = - 2.0 )
1939
+ int_min = torch .ops .aten .clamp_min (x , min = 2 )
1940
+ min = torch .ops .aten .clamp_min (x , min = 11.0 )
1941
+ return float_min , int_min , min
1942
+
1943
+
1944
+ @register_test_case (module_factory = lambda : ElementwiseClampMinFloat16Module ())
1945
+ def ElementwiseClampMinModule_float16 (module , tu : TestUtils ):
1946
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .float16 ))
1947
+
1948
+
1949
+ # ==============================================================================
1950
+
1951
+
1798
1952
class ElementwiseClampMaxModule (torch .nn .Module ):
1799
1953
def __init__ (self ):
1800
1954
super ().__init__ ()
@@ -1821,6 +1975,58 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
1821
1975
# ==============================================================================
1822
1976
1823
1977
1978
+ class ElementwiseClampMaxBFloat16Module (torch .nn .Module ):
1979
+ def __init__ (self ):
1980
+ super ().__init__ ()
1981
+
1982
+ @export
1983
+ @annotate_args (
1984
+ [
1985
+ None ,
1986
+ ([- 1 , - 1 ], torch .bfloat16 , True ),
1987
+ ]
1988
+ )
1989
+ def forward (self , x ):
1990
+ float_max = torch .ops .aten .clamp_max (x , max = 2.0 )
1991
+ int_max = torch .ops .aten .clamp_max (x , max = 3 )
1992
+ max = torch .ops .aten .clamp_max (x , max = - 11.0 )
1993
+ return float_max , int_max , max
1994
+
1995
+
1996
+ @register_test_case (module_factory = lambda : ElementwiseClampMaxBFloat16Module ())
1997
+ def ElementwiseClampMaxModule_bfloat16 (module , tu : TestUtils ):
1998
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .bfloat16 ))
1999
+
2000
+
2001
+ # ==============================================================================
2002
+
2003
+
2004
+ class ElementwiseClampMaxFloat16Module (torch .nn .Module ):
2005
+ def __init__ (self ):
2006
+ super ().__init__ ()
2007
+
2008
+ @export
2009
+ @annotate_args (
2010
+ [
2011
+ None ,
2012
+ ([- 1 , - 1 ], torch .float16 , True ),
2013
+ ]
2014
+ )
2015
+ def forward (self , x ):
2016
+ float_max = torch .ops .aten .clamp_max (x , max = 2.0 )
2017
+ int_max = torch .ops .aten .clamp_max (x , max = 3 )
2018
+ max = torch .ops .aten .clamp_max (x , max = - 11.0 )
2019
+ return float_max , int_max , max
2020
+
2021
+
2022
+ @register_test_case (module_factory = lambda : ElementwiseClampMaxFloat16Module ())
2023
+ def ElementwiseClampMaxModule_float16 (module , tu : TestUtils ):
2024
+ module .forward (tu .rand (3 , 5 , low = - 10 , high = 10 ).to (torch .float16 ))
2025
+
2026
+
2027
+ # ==============================================================================
2028
+
2029
+
1824
2030
class ElementwiseClampTensorFloatModule (torch .nn .Module ):
1825
2031
def __init__ (self ):
1826
2032
super ().__init__ ()
0 commit comments