From f4c2dd57acfe5382ff8d12c072df9bc2bc8a6fef Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Thu, 11 Sep 2025 19:37:48 +0000
Subject: [PATCH] [ROCm] fix miopen batchnorm changing output format (#162112)

The MIOpen batch norm integration was forcing its output into the default
contiguous (NCHW) memory format even when the input was channels-last; the
output and gradients now follow the input's suggested memory format. This
also unskips a number of related unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162112
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily
Co-authored-by: Dmitry Nikolaev
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
---
 aten/src/ATen/native/Normalization.cpp        |  4 ++-
 .../ATen/native/miopen/BatchNorm_miopen.cpp   | 34 +++++++++----------
 test/functorch/test_ops.py                    | 14 --------
 test/nn/test_convolution.py                   | 29 +++++-----------
 test/test_nn.py                               | 17 ++++++++--
 tools/autograd/derivatives.yaml               |  2 +-
 6 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 7327bf2d7e30b..13b421d1e6888 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -624,7 +624,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
         at::miopen_batch_norm(
-            input.contiguous(), weight.contiguous(), bias.contiguous(),
+            input.contiguous(input.suggest_memory_format()),
+            weight.contiguous(),
+            bias.contiguous(),
             running_mean.defined() ? running_mean.contiguous() : running_mean,
             running_var.defined() ? running_var.contiguous() : running_var,
             training, momentum, eps),
diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
index af69dfc76e571..0c122c9e13d4d 100644
--- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
+++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -7,6 +7,7 @@
 #include
 #else
 #include
+#include <ATen/ops/empty_like.h>
 #include
 #include
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }

-  auto output_t = at::empty(input->sizes(), input->options());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };

   auto handle = getMiopenHandle();
@@ -170,20 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
-
-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_t, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());
+
+  auto grad_output_contig =
+      grad_output_t.contiguous(input_t.suggest_memory_format());
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
@@ -195,7 +191,11 @@
std::tuple miopen_batch_norm_backward( } checkAllSameType(c, {input, grad_output}); checkAllSameType(c, {weight, save_mean, save_var}); - checkAllContiguous(c, {input, grad_output, save_mean, save_var}); + // TODO: is weight required to be contiguous? + checkAllContiguous(c, {save_mean, save_var}); + // TODO: TensorArg check should start handle memory format + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); checkDimRange(c, input, 2, 6 /* exclusive */); checkSameSize(c, input, grad_output); auto num_features = input->size(1); @@ -210,7 +210,7 @@ std::tuple miopen_batch_norm_backward( mode = miopenBNSpatial; } - auto grad_input_t = at::empty(input->sizes(), input->options()); + auto grad_input_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); auto grad_weight_t = at::empty(weight->sizes(), weight->options()); auto grad_bias_t = at::empty(weight->sizes(), weight->options()); diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 78e64278cb1e2..a2c88f7c35a13 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -468,13 +468,6 @@ class TestOperators(TestCase): ), # Works on ROCm xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), - # RuntimeError: Expected contiguous tensor, but got - # non-contiguous tensor for argument #2 'grad_output' - decorate( - "_batch_norm_with_update", - decorator=expectedFailureIf(TEST_WITH_ROCM), - device_type="cuda", - ), } ), ) @@ -2400,13 +2393,6 @@ def fn(input, weight, bias): skip("sparse.sampled_addmm", ""), skip("sparse.mm", "reduce"), skip("native_layer_norm", "", device_type="cpu"), - # RuntimeError: Expected contiguous tensor, but got - # non-contiguous tensor for argument #2 'grad_output' - decorate( - "_batch_norm_with_update", - decorator=expectedFailureIf(TEST_WITH_ROCM), - device_type="cuda", - ), }, ) @opsToleranceOverride( diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 7dacfeed003cc..2e378b77968b5 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -30,7 +30,6 @@ skipCUDAIfMiopen, skipCUDAIfNoCudnn, skipCUDAIfNoMiopen, - skipCUDAIfNotMiopenSuggestNHWC, skipCUDAIfRocm, skipMeta, skipMPS, @@ -52,9 +51,7 @@ parametrize as parametrize_test, run_tests, set_default_dtype, - skipIfNotMiopenSuggestNHWC, skipIfRocmArch, - skipIfRocmVersionLessThan, subtest, TEST_SCIPY, TEST_WITH_ROCM, @@ -67,6 +64,7 @@ if TEST_WITH_ROCM: os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1" if TEST_SCIPY: @@ -718,7 +716,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self): # Almost identical to the above `test_Conv2d_naive_groups` @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) - @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias(self): dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: @@ -764,7 +761,6 @@ def test_Conv2d_groups_nobias(self): # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) - @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias_v2(self): torch.manual_seed(123) dev_dtypes = [("cpu", torch.float)] @@ -899,7 +895,6 @@ def 
test_conv_tbc(self): @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @unittest.skipIf(not TEST_CUDNN, "needs cudnn") - @skipIfNotMiopenSuggestNHWC def test_grouped_conv_cudnn_nhwc_support(self): # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to( @@ -3149,7 +3144,6 @@ def test_conv_noncontig_weights_and_bias(self, device): @onlyCUDA @largeTensorTest("12GB") - @skipIfRocmVersionLessThan((6, 0)) def test_conv_transposed_large(self, device): dtype = torch.half if self.device_type == "cuda" else torch.float conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) @@ -3193,7 +3187,6 @@ def test_conv_transposed_large(self, device): self.assertEqual(maxdiff3, 0) @onlyCUDA - @skipCUDAIfRocm @largeTensorTest("12GB") def test_conv_large(self, device): dtype = torch.half if self.device_type == "cuda" else torch.float @@ -3226,7 +3219,6 @@ def test_conv_large(self, device): self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) @onlyCUDA - @skipCUDAIfRocm @largeTensorTest("20GB", "cpu") @largeTensorTest("60GB", "cuda") def test_conv_large_batch_1(self, device): @@ -3363,7 +3355,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device): @dtypes(torch.float) @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) - @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions matches two half convolutions m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) @@ -3632,19 +3623,21 @@ def helper( ) @onlyCUDA - @skipCUDAIfNotMiopenSuggestNHWC @dtypes(torch.half, torch.float, torch.cfloat) def test_conv_cudnn_nhwc(self, device, dtype): def helper(n, c, h, w, out_channels, kernel_size, groups): - input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to( - memory_format=torch.channels_last - ) + # randint with dtype=torch.cfloat fails with + # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types + # must create randint and randint_like using default int64, then cast to desired + input = torch.randint( + -3, 3, (n, c, h, w), dtype=torch.int64, device=device + ).to(dtype, memory_format=torch.channels_last) input.requires_grad_() conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to( device="cuda", dtype=dtype, memory_format=torch.channels_last ) for p in conv.parameters(): - p.data = torch.randint_like(p, -3, 3) + p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype) # use FP64 channels-first conv as reference ref_input = input.detach().clone().contiguous().double().requires_grad_() @@ -3658,7 +3651,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups): out = conv(input) ref_out = ref_conv(ref_input) - grad = torch.randint_like(out, -3, 3) + grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype) ref_grad = grad.detach().clone().double().contiguous() out.backward(grad) @@ -3685,7 +3678,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups): helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.half, torch.float) def test_conv_cudnn_ndhwc(self, device, dtype): def helper(n, c, d, h, w, out_channels, kernel_size, groups): @@ -3815,7 +3807,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): ) @onlyCUDA - 
@skipCUDAIfNotMiopenSuggestNHWC @tf32_on_and_off(0.05) def test_conv_cudnn_mismatch_memory_format(self, device): configs = [ @@ -3949,7 +3940,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype): self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) @onlyCUDA - @skipCUDAIfRocm def test_convert_conv2d_weight_memory_format(self, device): input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float() @@ -3969,7 +3959,6 @@ def test_convert_conv2d_weight_memory_format(self, device): self.assertTrue(out.is_contiguous(memory_format=memory_format)) @onlyCUDA - @skipCUDAIfRocm def test_convert_conv3d_weight_memory_format(self, device): input = torch.randint( 1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device diff --git a/test/test_nn.py b/test/test_nn.py index 0c84d6ffe129e..5092e36a3e01f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -62,6 +62,7 @@ if TEST_WITH_ROCM: os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1" # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -3513,7 +3514,6 @@ def test_cudnn_forward_exception(self): self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @skipIfRocm def test_cudnn_weight_format(self): rnns = [ nn.LSTM(10, 20, batch_first=True), @@ -3521,7 +3521,8 @@ def test_cudnn_weight_format(self): nn.GRU(10, 20, batch_first=True), nn.RNN(10, 20, batch_first=True) ] - first_warn = True + # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it + first_warn = False if torch.version.hip else True for rnn in rnns: rnn.cuda() input = torch.randn(5, 4, 10, requires_grad=True, device="cuda") @@ -5170,6 +5171,18 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): ("NCHW", "native", False, torch.float), ("NCHW", "native", True, torch.half), ("NCHW", "native", True, torch.bfloat16), + + ("NHWC", "cpu", False, torch.float), + ("NHWC", "cpu", True, torch.half), + ("NHWC", "cpu", True, torch.bfloat16), + + ("NHWC", "native", False, torch.float), + ("NHWC", "native", True, torch.half), + ("NHWC", "native", True, torch.bfloat16), + + ("NHWC", "NCHW", False, torch.float), + ("NHWC", "NCHW", True, torch.half), + ("NHWC", "NCHW", True, torch.bfloat16), ], name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}" ) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index c050c6cbdc4c3..506d829b5712c 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2801,7 +2801,7 @@ self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) - input, weight, bias: "grad.defined() ? (training ? 
miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
   result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

 - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
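
Reviewer note (not part of the patch): the sketch below is one way to spot-check the behavior this change targets. It assumes a PyTorch build with GPU support; on ROCm it exercises the MIOpen batch norm path fixed here, and the PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM=1 environment variable mirrors what the updated tests set. The shapes and variable names are illustrative only, not taken from the patch.

    # Minimal sketch: check that BatchNorm2d keeps a channels-last input's
    # memory format in both the output and the input gradient.
    import os

    # Mirrors the env var set in test/test_nn.py and test/nn/test_convolution.py on ROCm.
    os.environ.setdefault("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM", "1")

    import torch
    import torch.nn as nn

    # ROCm builds expose the HIP device through the "cuda" device type.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Leaf tensor in channels-last (NHWC) memory format.
    x = torch.randn(8, 32, 16, 16, device=device).to(memory_format=torch.channels_last)
    x.requires_grad_()

    bn = nn.BatchNorm2d(32).to(device)  # training mode by default
    out = bn(x)
    print("output channels_last:",
          out.is_contiguous(memory_format=torch.channels_last))

    # Backward with a grad_output in the same memory format as the output.
    out.backward(torch.randn_like(out))
    print("grad_input channels_last:",
          x.grad.is_contiguous(memory_format=torch.channels_last))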