@@ -1827,6 +1827,43 @@ def test_stride_for_index_Tensor(self):
 
         self.assertEqual(out.stride(), f_out.stride())
 
+
+    @parametrize("in_dtype", [torch.float32, torch.float16])
+    @parametrize("bias_dtype", [torch.float32, torch.float16, None])
+    def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
+        if in_dtype == torch.float16 and bias_dtype == torch.float32:
+            self.skipTest(f"not supported input dtype is {in_dtype} and bias dtype is {bias_dtype}")
+        device = "meta"
+
+        def fn(input, weight, bias, need_grad_input):
+            outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
+            grad_outs = torch.ones_like(outputs)
+            grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
+            return grad_ins
+
+        input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
+        need_grad_input = [input]
+
+        if bias_dtype:
+            weight = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            bias = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            need_grad_input.append(weight)
+            need_grad_input.append(bias)
+        else:
+            weight = None
+            bias = None
+
+        outs = fn(input, weight, bias, need_grad_input)
+        out_dtype = [t.dtype for t in outs]
+        if bias_dtype:
+            self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
+        else:
+            self.assertEqual(out_dtype, [in_dtype,])
+
 instantiate_device_type_tests(TestMeta, globals())
 
 def print_op_str_if_not_supported(op_str):
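
For reference, the behavior the new test pins down can be reproduced outside the test harness with a short standalone script; the concrete dtype pair (float32 input with float16 weight/bias) and the shapes below are illustrative choices, not part of the patch:

import torch

# Meta tensors carry only shape/dtype metadata, so this runs without real storage.
device = "meta"
in_dtype, param_dtype = torch.float32, torch.float16

x = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
weight = torch.randn([5], dtype=param_dtype, device=device, requires_grad=True)
bias = torch.randn([5], dtype=param_dtype, device=device, requires_grad=True)

out = torch.nn.functional.layer_norm(x, x.shape[-1:], weight, bias)
grads = torch.autograd.grad(out, [x, weight, bias], torch.ones_like(out))

# Expected: the grad w.r.t. the input keeps float32, while the grads w.r.t.
# weight and bias keep float16 (this is what the test's assertEqual checks).
print([g.dtype for g in grads])  # [torch.float32, torch.float16, torch.float16]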