Skip to content

Commit 63276ed

Browse files
markc-614pytorchmergebot
authored andcommitted
[Inductor] support mixed dtype in the native_layer_norm_backward meta function (pytorch#159830)
Fixes pytorch#159829 Pull Request resolved: pytorch#159830 Approved by: https://github.com/albanD
1 parent dfda2df commit 63276ed

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

test/test_meta.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,6 +1827,43 @@ def test_stride_for_index_Tensor(self):
18271827

18281828
self.assertEqual(out.stride(), f_out.stride())
18291829

1830+
1831+
@parametrize("in_dtype", [torch.float32, torch.float16])
1832+
@parametrize("bias_dtype", [torch.float32, torch.float16, None])
1833+
def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
1834+
if in_dtype == torch.float16 and bias_dtype == torch.float32:
1835+
self.skipTest(f"not supported input dtype is {in_dtype} and bias dtype is {bias_dtype}")
1836+
device = "meta"
1837+
1838+
def fn(input, weight, bias, need_grad_input):
1839+
outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
1840+
grad_outs = torch.ones_like(outputs)
1841+
grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
1842+
return grad_ins
1843+
1844+
input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
1845+
need_grad_input = [input]
1846+
1847+
if bias_dtype:
1848+
weight = torch.randn(
1849+
[5], dtype=bias_dtype, device=device, requires_grad=True
1850+
)
1851+
bias = torch.randn(
1852+
[5], dtype=bias_dtype, device=device, requires_grad=True
1853+
)
1854+
need_grad_input.append(weight)
1855+
need_grad_input.append(bias)
1856+
else:
1857+
weight = None
1858+
bias = None
1859+
1860+
outs = fn(input, weight, bias, need_grad_input)
1861+
out_dtype = [t.dtype for t in outs]
1862+
if bias_dtype:
1863+
self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
1864+
else:
1865+
self.assertEqual(out_dtype, [in_dtype,])
1866+
18301867
instantiate_device_type_tests(TestMeta, globals())
18311868

18321869
def print_op_str_if_not_supported(op_str):

torch/_decomp/decompositions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,8 +1710,8 @@ def native_layer_norm_backward(
17101710

17111711
return (
17121712
_maybe_cast(d_input, input.dtype),
1713-
_maybe_cast(d_weight, input.dtype),
1714-
_maybe_cast(d_bias, input.dtype),
1713+
_maybe_cast(d_weight, weight.dtype if weight is not None else None),
1714+
_maybe_cast(d_bias, bias.dtype if bias is not None else None),
17151715
)
17161716

17171717

0 commit comments

Comments
 (0)