Skip to content

Commit bd4bf5b

Browse files
jeffdailyjerrymannil
authored andcommitted
[release/2.9] fix miopen batchnorm changing output format
cherry pick of pytorch#162112
1 parent f363ae8 commit bd4bf5b

File tree

3 files changed

+23
-37
lines changed

3 files changed

+23
-37
lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/NativeFunctions.h>
88
#else
99
#include <ATen/ops/empty.h>
10+
#include <ATen/ops/empty_like.h>
1011
#include <ATen/ops/miopen_batch_norm_native.h>
1112
#include <ATen/ops/miopen_batch_norm_backward_native.h>
1213
#endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
102103
mode = miopenBNSpatial;
103104
}
104105

105-
auto output_t = at::empty(input->sizes(), input->options());
106+
auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
106107
TensorArg output{ output_t, "output", 0 };
107108

108109
auto handle = getMiopenHandle();
@@ -170,20 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
170171
const std::optional<Tensor>& save_var_t_opt,
171172
double epsilon) {
172173
// See [Note: hacky wrapper removal for optional tensor]
173-
const Tensor& running_mean =
174-
running_mean_opt.value_or(Tensor());
175-
const Tensor& running_var =
176-
running_var_opt.value_or(Tensor());
177-
const Tensor& save_mean_t =
178-
save_mean_t_opt.value_or(Tensor());
179-
const Tensor& save_var_t =
180-
save_var_t_opt.value_or(Tensor());
181-
182-
TensorArg input{ input_t, "input", 1 },
183-
grad_output{ grad_output_t, "grad_output", 2 },
184-
weight{ weight_t, "weight", 3 },
185-
save_mean{ save_mean_t, "save_mean", 4 },
186-
save_var{ save_var_t, "save_var", 5 };
174+
const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
175+
const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());
176+
177+
auto grad_output_contig =
178+
grad_output_t.contiguous(input_t.suggest_memory_format());
179+
TensorArg input{input_t, "input", 1},
180+
grad_output{grad_output_contig, "grad_output", 2},
181+
weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
182+
save_var{save_var_t, "save_var", 5};
187183
CheckedFrom c = "miopen_batch_norm_backward";
188184

189185
checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});

test/nn/test_convolution.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
skipCUDAIfMiopen,
3131
skipCUDAIfNoCudnn,
3232
skipCUDAIfNoMiopen,
33-
skipCUDAIfNotMiopenSuggestNHWC,
3433
skipCUDAIfRocm,
3534
skipMeta,
3635
skipMPS,
@@ -52,9 +51,7 @@
5251
parametrize as parametrize_test,
5352
run_tests,
5453
set_default_dtype,
55-
skipIfNotMiopenSuggestNHWC,
5654
skipIfRocmArch,
57-
skipIfRocmVersionLessThan,
5855
subtest,
5956
TEST_SCIPY,
6057
TEST_WITH_ROCM,
@@ -67,6 +64,7 @@
6764

6865
if TEST_WITH_ROCM:
6966
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
67+
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
7068

7169

7270
if TEST_SCIPY:
@@ -718,7 +716,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
718716
# Almost identical to the above `test_Conv2d_naive_groups`
719717
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
720718
@tf32_on_and_off(0.001)
721-
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
722719
def test_Conv2d_groups_nobias(self):
723720
dev_dtypes = [("cpu", torch.float)]
724721
if TEST_CUDA:
@@ -764,7 +761,6 @@ def test_Conv2d_groups_nobias(self):
764761
# and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
765762
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
766763
@tf32_on_and_off(0.001)
767-
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
768764
def test_Conv2d_groups_nobias_v2(self):
769765
torch.manual_seed(123)
770766
dev_dtypes = [("cpu", torch.float)]
@@ -899,7 +895,6 @@ def test_conv_tbc(self):
899895

900896
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
901897
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
902-
@skipIfNotMiopenSuggestNHWC
903898
def test_grouped_conv_cudnn_nhwc_support(self):
904899
# in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
905900
input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3149,7 +3144,6 @@ def test_conv_noncontig_weights_and_bias(self, device):
31493144

31503145
@onlyCUDA
31513146
@largeTensorTest("12GB")
3152-
@skipIfRocmVersionLessThan((6, 0))
31533147
def test_conv_transposed_large(self, device):
31543148
dtype = torch.half if self.device_type == "cuda" else torch.float
31553149
conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3193,7 +3187,6 @@ def test_conv_transposed_large(self, device):
31933187
self.assertEqual(maxdiff3, 0)
31943188

31953189
@onlyCUDA
3196-
@skipCUDAIfRocm
31973190
@largeTensorTest("12GB")
31983191
def test_conv_large(self, device):
31993192
dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3226,7 +3219,6 @@ def test_conv_large(self, device):
32263219
self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
32273220

32283221
@onlyCUDA
3229-
@skipCUDAIfRocm
32303222
@largeTensorTest("20GB", "cpu")
32313223
@largeTensorTest("60GB", "cuda")
32323224
def test_conv_large_batch_1(self, device):
@@ -3363,7 +3355,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
33633355
@dtypes(torch.float)
33643356
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
33653357
@tf32_on_and_off(0.001)
3366-
@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
33673358
def test_Conv2d_naive_groups(self, device, dtype):
33683359
# Check that grouped convolutions matches two half convolutions
33693360
m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3632,19 +3623,21 @@ def helper(
36323623
)
36333624

36343625
@onlyCUDA
3635-
@skipCUDAIfNotMiopenSuggestNHWC
36363626
@dtypes(torch.half, torch.float, torch.cfloat)
36373627
def test_conv_cudnn_nhwc(self, device, dtype):
36383628
def helper(n, c, h, w, out_channels, kernel_size, groups):
3639-
input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
3640-
memory_format=torch.channels_last
3641-
)
3629+
# randint with dtype=torch.cfloat fails with
3630+
# RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
3631+
# must create randint and randint_like using default int64, then cast to desired
3632+
input = torch.randint(
3633+
-3, 3, (n, c, h, w), dtype=torch.int64, device=device
3634+
).to(dtype, memory_format=torch.channels_last)
36423635
input.requires_grad_()
36433636
conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
36443637
device="cuda", dtype=dtype, memory_format=torch.channels_last
36453638
)
36463639
for p in conv.parameters():
3647-
p.data = torch.randint_like(p, -3, 3)
3640+
p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)
36483641

36493642
# use FP64 channels-first conv as reference
36503643
ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3658,7 +3651,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
36583651
out = conv(input)
36593652
ref_out = ref_conv(ref_input)
36603653

3661-
grad = torch.randint_like(out, -3, 3)
3654+
grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
36623655
ref_grad = grad.detach().clone().double().contiguous()
36633656

36643657
out.backward(grad)
@@ -3685,7 +3678,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
36853678
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
36863679

36873680
@onlyCUDA
3688-
@skipCUDAIfRocm
36893681
@dtypes(torch.half, torch.float)
36903682
def test_conv_cudnn_ndhwc(self, device, dtype):
36913683
def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3815,7 +3807,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
38153807
)
38163808

38173809
@onlyCUDA
3818-
@skipCUDAIfNotMiopenSuggestNHWC
38193810
@tf32_on_and_off(0.05)
38203811
def test_conv_cudnn_mismatch_memory_format(self, device):
38213812
configs = [
@@ -3949,7 +3940,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
39493940
self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)
39503941

39513942
@onlyCUDA
3952-
@skipCUDAIfRocm
39533943
def test_convert_conv2d_weight_memory_format(self, device):
39543944
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
39553945
model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3969,7 +3959,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
39693959
self.assertTrue(out.is_contiguous(memory_format=memory_format))
39703960

39713961
@onlyCUDA
3972-
@skipCUDAIfRocm
39733962
def test_convert_conv3d_weight_memory_format(self, device):
39743963
input = torch.randint(
39753964
1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device

test/test_nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
if TEST_WITH_ROCM:
6464
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
65+
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
6566

6667
# load_tests from common_utils is used to automatically filter tests for
6768
# sharding on sandcastle. This line silences flake warnings
@@ -3513,15 +3514,15 @@ def test_cudnn_forward_exception(self):
35133514
self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)
35143515

35153516
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
3516-
@skipIfRocm
35173517
def test_cudnn_weight_format(self):
35183518
rnns = [
35193519
nn.LSTM(10, 20, batch_first=True),
35203520
nn.LSTM(10, 20, batch_first=True, proj_size=10),
35213521
nn.GRU(10, 20, batch_first=True),
35223522
nn.RNN(10, 20, batch_first=True)
35233523
]
3524-
first_warn = True
3524+
# ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
3525+
first_warn = False if torch.version.hip else True
35253526
for rnn in rnns:
35263527
rnn.cuda()
35273528
input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")

0 commit comments

Comments
 (0)