
Commit 846316e

Authored by jerrymannil, jeffdaily, dnikolaev-amd, and jithunnair-amd
[release/2.9] fix miopen batchnorm changing output format (#2813)
Cherry pick of pytorch#162112. Fixes #SWDEV-567460.

Co-authored-by: Jeff Daily <[email protected]>
Co-authored-by: Dmitry Nikolaev <[email protected]>
Co-authored-by: Jithun Nair <[email protected]>
1 parent b8d38dd commit 846316e
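Not part of the diff: a minimal sketch of the behavior this fix targets, assuming a ROCm build with the NHWC batch norm path enabled (the tests below export PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM=1 for that). A channels_last input to BatchNorm2d should now produce a channels_last output instead of being silently converted to NCHW.

import torch
import torch.nn as nn

# Illustrative shapes; assumes the ROCm MIOpen NHWC batch norm path is active.
x = torch.randn(8, 32, 16, 16, device="cuda").to(memory_format=torch.channels_last)
bn = nn.BatchNorm2d(32).cuda()

out = bn(x)
# Before the fix the MIOpen path returned an NCHW-contiguous tensor here;
# with the fix the output is expected to keep the input's memory format.
print(out.is_contiguous(memory_format=torch.channels_last))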

File tree: 6 files changed, +45 -55 lines

aten/src/ATen/native/Normalization.cpp

Lines changed: 3 additions & 1 deletion
@@ -624,7 +624,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
         at::miopen_batch_norm(
-            input.contiguous(), weight.contiguous(), bias.contiguous(),
+            input.contiguous(input.suggest_memory_format()),
+            weight.contiguous(),
+            bias.contiguous(),
             running_mean.defined() ? running_mean.contiguous() : running_mean,
             running_var.defined() ? running_var.contiguous() : running_var,
             training, momentum, eps),
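Context for the one-line change above (illustrative, not from the diff): a bare Tensor.contiguous() forces the default NCHW layout, while passing a memory format keeps a channels_last input as-is.

import torch

x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
a = x.contiguous()                                   # forces NCHW strides
b = x.contiguous(memory_format=torch.channels_last)  # keeps NHWC strides
print(a.is_contiguous(memory_format=torch.channels_last))  # False
print(b.is_contiguous(memory_format=torch.channels_last))  # True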

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 17 additions & 17 deletions
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
 #include <ATen/ops/miopen_batch_norm_native.h>
 #include <ATen/ops/miopen_batch_norm_backward_native.h>
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }

-  auto output_t = at::empty(input->sizes(), input->options());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };

   auto handle = getMiopenHandle();
@@ -170,20 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
-
-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_t, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());
+
+  auto grad_output_contig =
+      grad_output_t.contiguous(input_t.suggest_memory_format());
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
@@ -195,7 +191,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
   }
   checkAllSameType(c, {input, grad_output});
   checkAllSameType(c, {weight, save_mean, save_var});
-  checkAllContiguous(c, {input, grad_output, save_mean, save_var});
+  // TODO: is weight required to be contiguous?
+  checkAllContiguous(c, {save_mean, save_var});
+  // TODO: TensorArg check should start handle memory format
+  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
+  TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
   checkDimRange(c, input, 2, 6 /* exclusive */);
   checkSameSize(c, input, grad_output);
   auto num_features = input->size(1);
@@ -210,7 +210,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     mode = miopenBNSpatial;
   }

-  auto grad_input_t = at::empty(input->sizes(), input->options());
+  auto grad_input_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
   auto grad_weight_t = at::empty(weight->sizes(), weight->options());
   auto grad_bias_t = at::empty(weight->sizes(), weight->options());
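The allocation changes above make the output and grad_input follow the input's suggested memory format instead of defaulting to NCHW. A rough Python-level analogue (plain torch factory functions stand in for the ATen calls):

import torch

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)

y_default = torch.empty(x.shape)                                        # NCHW strides
y_preserve = torch.empty_like(x, memory_format=torch.preserve_format)   # keeps NHWC strides

print(y_default.is_contiguous(memory_format=torch.channels_last))   # False
print(y_preserve.is_contiguous(memory_format=torch.channels_last))  # True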

test/functorch/test_ops.py

Lines changed: 0 additions & 14 deletions
@@ -468,13 +468,6 @@ class TestOperators(TestCase):
                 ),  # Works on ROCm
                 xfail("torch.ops.aten._flash_attention_forward"),
                 xfail("torch.ops.aten._efficient_attention_forward"),
-                # RuntimeError: Expected contiguous tensor, but got
-                # non-contiguous tensor for argument #2 'grad_output'
-                decorate(
-                    "_batch_norm_with_update",
-                    decorator=expectedFailureIf(TEST_WITH_ROCM),
-                    device_type="cuda",
-                ),
             }
         ),
     )
@@ -2400,13 +2393,6 @@ def fn(input, weight, bias):
             skip("sparse.sampled_addmm", ""),
             skip("sparse.mm", "reduce"),
             skip("native_layer_norm", "", device_type="cpu"),
-            # RuntimeError: Expected contiguous tensor, but got
-            # non-contiguous tensor for argument #2 'grad_output'
-            decorate(
-                "_batch_norm_with_update",
-                decorator=expectedFailureIf(TEST_WITH_ROCM),
-                device_type="cuda",
-            ),
         },
     )
     @opsToleranceOverride(
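These two entries were expected failures because the ROCm batch norm backward rejected a non-contiguous grad_output; with grad_output now made contiguous in the input's suggested format they are no longer needed. A rough sketch of the kind of call that used to hit the error (not the exact functorch repro):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16).cuda()
x = torch.randn(4, 16, 8, 8, device="cuda", requires_grad=True)
out = bn(x)

# A channels_last grad is not NCHW-contiguous; on ROCm this used to raise
# "Expected contiguous tensor, but got non-contiguous tensor for argument #2 'grad_output'".
grad = torch.randn_like(out).to(memory_format=torch.channels_last)
out.backward(grad)
print(x.grad.shape)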

test/nn/test_convolution.py

Lines changed: 9 additions & 20 deletions
@@ -30,7 +30,6 @@
     skipCUDAIfMiopen,
     skipCUDAIfNoCudnn,
     skipCUDAIfNoMiopen,
-    skipCUDAIfNotMiopenSuggestNHWC,
     skipCUDAIfRocm,
     skipMeta,
     skipMPS,
@@ -52,9 +51,7 @@
     parametrize as parametrize_test,
     run_tests,
     set_default_dtype,
-    skipIfNotMiopenSuggestNHWC,
     skipIfRocmArch,
-    skipIfRocmVersionLessThan,
     subtest,
     TEST_SCIPY,
     TEST_WITH_ROCM,
@@ -67,6 +64,7 @@

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


 if TEST_SCIPY:
@@ -718,7 +716,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
     # Almost identical to the above `test_Conv2d_naive_groups`
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias(self):
         dev_dtypes = [("cpu", torch.float)]
         if TEST_CUDA:
@@ -764,7 +761,6 @@ def test_Conv2d_groups_nobias(self):
     # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias_v2(self):
         torch.manual_seed(123)
         dev_dtypes = [("cpu", torch.float)]
@@ -899,7 +895,6 @@ def test_conv_tbc(self):

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3149,7 +3144,6 @@ def test_conv_noncontig_weights_and_bias(self, device):

     @onlyCUDA
     @largeTensorTest("12GB")
-    @skipIfRocmVersionLessThan((6, 0))
     def test_conv_transposed_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
         conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3193,7 +3187,6 @@ def test_conv_transposed_large(self, device):
         self.assertEqual(maxdiff3, 0)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("12GB")
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3226,7 +3219,6 @@ def test_conv_large(self, device):
         self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("20GB", "cpu")
     @largeTensorTest("60GB", "cuda")
     def test_conv_large_batch_1(self, device):
@@ -3363,7 +3355,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
     @dtypes(torch.float)
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_naive_groups(self, device, dtype):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3632,19 +3623,21 @@ def helper(
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @dtypes(torch.half, torch.float, torch.cfloat)
     def test_conv_cudnn_nhwc(self, device, dtype):
         def helper(n, c, h, w, out_channels, kernel_size, groups):
-            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
-                memory_format=torch.channels_last
-            )
+            # randint with dtype=torch.cfloat fails with
+            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
+            # must create randint and randint_like using default int64, then cast to desired
+            input = torch.randint(
+                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
+            ).to(dtype, memory_format=torch.channels_last)
             input.requires_grad_()
             conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                 device="cuda", dtype=dtype, memory_format=torch.channels_last
             )
             for p in conv.parameters():
-                p.data = torch.randint_like(p, -3, 3)
+                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

             # use FP64 channels-first conv as reference
             ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3658,7 +3651,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
             out = conv(input)
             ref_out = ref_conv(ref_input)

-            grad = torch.randint_like(out, -3, 3)
+            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
             ref_grad = grad.detach().clone().double().contiguous()

             out.backward(grad)
@@ -3685,7 +3678,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
         helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_ndhwc(self, device, dtype):
         def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3815,7 +3807,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
@@ -3949,7 +3940,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
         self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv2d_weight_memory_format(self, device):
         input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
         model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3969,7 +3959,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
         self.assertTrue(out.is_contiguous(memory_format=memory_format))

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv3d_weight_memory_format(self, device):
         input = torch.randint(
             1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
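The randint changes above exist because torch.randint and torch.randint_like do not accept complex dtypes; the workaround in isolation (values are illustrative):

import torch

# torch.randint(-3, 3, (4, 4), dtype=torch.cfloat) raises
# "check_random_bounds handles only integral, floating-point and boolean types",
# so draw integers with an integral dtype first, then cast to the target dtype.
x = torch.randint(-3, 3, (4, 4), dtype=torch.int64).to(torch.cfloat)
print(x.dtype)  # torch.complex64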

test/test_nn.py

Lines changed: 15 additions & 2 deletions
@@ -62,6 +62,7 @@

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -3513,15 +3514,15 @@ def test_cudnn_forward_exception(self):
         self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
             nn.LSTM(10, 20, batch_first=True, proj_size=10),
             nn.GRU(10, 20, batch_first=True),
             nn.RNN(10, 20, batch_first=True)
         ]
-        first_warn = True
+        # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
+        first_warn = False if torch.version.hip else True
         for rnn in rnns:
             rnn.cuda()
             input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
@@ -5170,6 +5171,18 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
             ("NCHW", "native", False, torch.float),
             ("NCHW", "native", True, torch.half),
             ("NCHW", "native", True, torch.bfloat16),
+
+            ("NHWC", "cpu", False, torch.float),
+            ("NHWC", "cpu", True, torch.half),
+            ("NHWC", "cpu", True, torch.bfloat16),
+
+            ("NHWC", "native", False, torch.float),
+            ("NHWC", "native", True, torch.half),
+            ("NHWC", "native", True, torch.bfloat16),
+
+            ("NHWC", "NCHW", False, torch.float),
+            ("NHWC", "NCHW", True, torch.half),
+            ("NHWC", "NCHW", True, torch.bfloat16),
         ],
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
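The new NHWC rows extend an existing batch norm comparison across memory formats, backends, and mixed dtypes. A simplified stand-alone version of that comparison (not the parametrized test itself):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).cuda()
x = torch.randn(4, 8, 16, 16, device="cuda")

ref = bn(x.contiguous())                           # NCHW reference
out = bn(x.to(memory_format=torch.channels_last))  # NHWC path
torch.testing.assert_close(out, ref)               # values must agree across formats
print(out.is_contiguous(memory_format=torch.channels_last))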

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -2801,7 +2801,7 @@
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
   result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

 - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
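Combined with the ATen changes, the tweaked formula above keeps the whole backward in the input's memory format. A quick sketch of the gradient layout one would expect on the MIOpen NHWC path (again assuming that path is enabled):

import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16, device="cuda").to(
    memory_format=torch.channels_last).requires_grad_()
bn = nn.BatchNorm2d(8).cuda()

bn(x).sum().backward()
# grad_input is now allocated with the input's suggested memory format,
# so for a channels_last input this should print True.
print(x.grad.is_contiguous(memory_format=torch.channels_last))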
