Commit d3985e1

[release/2.8] fix miopen batchnorm changing output format

Cherry-pick of pytorch#162112

1 parent aeb6421 commit d3985e1
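For context, a minimal sketch (assumed setup, not part of the commit) of the behavior being fixed: on a ROCm build with MIOpen NHWC batch norm opted in, batch norm output should keep the input's channels_last memory format instead of silently reverting to contiguous (NCHW).

import os
# Opt-in flag used by the tests below; assumed to be read before the first
# batchnorm dispatch, so set it before touching CUDA/HIP.
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).cuda()
x = torch.randn(2, 8, 4, 4, device="cuda").to(memory_format=torch.channels_last)
out = bn(x)
# Before this fix, the MIOpen path could hand back a contiguous (NCHW) output.
assert out.is_contiguous(memory_format=torch.channels_last)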

3 files changed: +20 -36 lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 8 additions & 14 deletions
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
 #include <ATen/ops/miopen_batch_norm_native.h>
 #include <ATen/ops/miopen_batch_norm_backward_native.h>
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }

-  auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };

   auto handle = getMiopenHandle();
@@ -170,22 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());

   auto grad_output_contig =
       grad_output_t.contiguous(input_t.suggest_memory_format());
-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_contig, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
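The core of the fix is allocating the output with at::empty_like in the memory format the input suggests (hence the new empty_like.h include). A rough Python analogue of that allocation, illustrative only (the C++ side calls input_t.suggest_memory_format() rather than hard-coding channels_last):

import torch

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
# Allocate the output in the layout the input suggests, as the C++ change
# above does, rather than in default contiguous (NCHW) order.
out = torch.empty_like(x, memory_format=torch.channels_last)
assert out.is_contiguous(memory_format=torch.channels_last)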

test/nn/test_convolution.py

Lines changed: 9 additions & 20 deletions
@@ -30,7 +30,6 @@
     skipCUDAIfMiopen,
     skipCUDAIfNoCudnn,
     skipCUDAIfNoMiopen,
-    skipCUDAIfNotMiopenSuggestNHWC,
     skipCUDAIfRocm,
     skipMeta,
     skipMPS,
@@ -51,8 +50,6 @@
     parametrize as parametrize_test,
     run_tests,
     set_default_dtype,
-    skipIfNotMiopenSuggestNHWC,
-    skipIfRocmVersionLessThan,
     subtest,
     TEST_SCIPY,
     TEST_WITH_ROCM,
@@ -64,6 +61,7 @@

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


 if TEST_SCIPY:
@@ -715,7 +713,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
     # Almost identical to the above `test_Conv2d_naive_groups`
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias(self):
         dev_dtypes = [("cpu", torch.float)]
         if TEST_CUDA:
@@ -761,7 +758,6 @@ def test_Conv2d_groups_nobias(self):
     # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias_v2(self):
         torch.manual_seed(123)
         dev_dtypes = [("cpu", torch.float)]
@@ -896,7 +892,6 @@ def test_conv_tbc(self):

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3145,7 +3140,6 @@ def test_conv_noncontig_weights_and_bias(self, device):

     @onlyCUDA
     @largeTensorTest("12GB")
-    @skipIfRocmVersionLessThan((6, 0))
     def test_conv_transposed_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
         conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3189,7 +3183,6 @@ def test_conv_transposed_large(self, device):
         self.assertEqual(maxdiff3, 0)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("12GB")
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3222,7 +3215,6 @@ def test_conv_large(self, device):
         self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("20GB", "cpu")
     @largeTensorTest("60GB", "cuda")
     def test_conv_large_batch_1(self, device):
@@ -3370,7 +3362,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
     @dtypes(torch.float)
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_naive_groups(self, device, dtype):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3639,19 +3630,21 @@ def helper(
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @dtypes(torch.half, torch.float, torch.cfloat)
     def test_conv_cudnn_nhwc(self, device, dtype):
         def helper(n, c, h, w, out_channels, kernel_size, groups):
-            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
-                memory_format=torch.channels_last
-            )
+            # randint with dtype=torch.cfloat fails with
+            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
+            # must create randint and randint_like using default int64, then cast to desired
+            input = torch.randint(
+                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
+            ).to(dtype, memory_format=torch.channels_last)
             input.requires_grad_()
             conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                 device="cuda", dtype=dtype, memory_format=torch.channels_last
             )
             for p in conv.parameters():
-                p.data = torch.randint_like(p, -3, 3)
+                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

             # use FP64 channels-first conv as reference
             ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3665,7 +3658,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
             out = conv(input)
             ref_out = ref_conv(ref_input)

-            grad = torch.randint_like(out, -3, 3)
+            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
             ref_grad = grad.detach().clone().double().contiguous()

             out.backward(grad)
@@ -3692,7 +3685,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
         helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_ndhwc(self, device, dtype):
         def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3822,7 +3814,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
@@ -3955,7 +3946,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
         self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv2d_weight_memory_format(self, device):
         input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
         model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3975,7 +3965,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
         self.assertTrue(out.is_contiguous(memory_format=memory_format))

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv3d_weight_memory_format(self, device):
         input = torch.randint(
             1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
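The test change in test_conv_cudnn_nhwc also works around torch.randint rejecting complex dtypes: sample in the default int64 dtype first, then cast. A standalone sketch of that pattern:

import torch

# torch.randint does not accept torch.cfloat, so sample integers first and
# cast afterwards, keeping the desired memory format (as the test now does).
x = torch.randint(-3, 3, (1, 3, 4, 4), dtype=torch.int64).to(
    torch.cfloat, memory_format=torch.channels_last
)
assert x.dtype == torch.cfloat
assert x.is_contiguous(memory_format=torch.channels_last)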

test/test_nn.py

Lines changed: 3 additions & 2 deletions
@@ -60,6 +60,7 @@

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -3496,15 +3497,15 @@ def test_cudnn_forward_exception(self):
         self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
             nn.LSTM(10, 20, batch_first=True, proj_size=10),
             nn.GRU(10, 20, batch_first=True),
             nn.RNN(10, 20, batch_first=True)
         ]
-        first_warn = True
+        # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
+        first_warn = False if torch.version.hip else True
         for rnn in rnns:
             rnn.cuda()
             input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
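The re-enabled RNN test branches on the backend instead of skipping on ROCm. torch.version.hip is a version string on ROCm builds and None otherwise, so its truthiness distinguishes the two; a short sketch of the same check:

import torch

# torch.version.hip is e.g. "6.x..." on ROCm builds and None on CUDA/CPU
# builds, so it doubles as a backend test.
on_rocm = torch.version.hip is not None
first_warn = not on_rocm  # ROCm RNNs don't emit the contiguous-weights warning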
