Commit 272d56b

[rocm7.1_internal_testing] Fix nhwc batchnorm after ifu rocm7.1 (#2591)
NHWC batchnorm was completely lost after the rocm7.1_internal_testing IFU.

This PR:
- fixes the condition for invoking MIOpen for NHWC batchnorm
- removes duplicated test_batchnorm 2D/3D train/inference tests

Fixes SWDEV-553036
1 parent 28f820a

2 files changed: +0 −171 lines

aten/src/ATen/native/Normalization.cpp

Lines changed: 0 additions & 2 deletions
@@ -554,8 +554,6 @@ BatchNormBackend _select_batch_norm_backend(
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
       && (input.suggest_memory_format() == MemoryFormat::Contiguous
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60500)
         || (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM)
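
The two deleted lines unconditionally rejected channels_last / channels_last_3d inputs, which contradicted the ROCm branch directly below them that allows channels_last when PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM is enabled; dropping them restores the MIOpen NHWC batchnorm path. A minimal sketch of exercising that restored path follows; it assumes a ROCm >= 6.5 build and that the knob is exposed as an environment variable of the same name as the macro (an assumption, not something this diff shows):

# Sketch only: drive batchnorm with a channels_last (NHWC) input on a ROCm build.
# Assumption: the NHWC-batchnorm knob is read from an env var matching the macro name.
import os
os.environ.setdefault("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM", "1")  # assumed knob

import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16, device="cuda").to(memory_format=torch.channels_last)
x.requires_grad_()
bn = nn.BatchNorm2d(8, device="cuda").train()

out = bn(x)
out.sum().backward()

# When the NHWC path is taken, the output keeps the channels_last layout instead of
# being produced in contiguous (NCHW) form.
assert out.is_contiguous(memory_format=torch.channels_last)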

test/test_nn.py

Lines changed: 0 additions & 169 deletions
@@ -5145,175 +5145,6 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
         self.assertTrue(torch.equal(running_mean, bn.running_mean))
         self.assertTrue(torch.equal(running_var, bn.running_var))
 
-
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D")
-    @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x)
-    @parametrize_test(
-        # test verifies cudnn/miopen batchnorm with the reference backend or memory format
-        # memory_format - one of ("NCHW", NHWC")
-        # ref_backend - one of ("cpu", "native", "NCHW", "NHWC")
-        #   "cpu"    - cpu backend with the same memory_format will be used as reference
-        #   "native" - native backend (`with torch.backends.cudnn.flags(enabled=False)`)
-        #              with the same memory_format will be used
-        #   "NCHW" or "NHWC" - the same backend will be used but another memory format
-        # mixed - True or False. Mixed batchnorm mode where inputs are 16-bit and batchnorm is fp32
-        #
-        "memory_format,ref_backend,mixed,dtype",
-        [
-            ("NCHW", "cpu", False, torch.float),
-            ("NCHW", "cpu", True, torch.half),
-            ("NCHW", "cpu", True, torch.bfloat16),
-
-            ("NCHW", "native", False, torch.float),
-            ("NCHW", "native", True, torch.half),
-            ("NCHW", "native", True, torch.bfloat16),
-        ],
-        name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
-    )
-    def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
-        if torch.version.cuda:
-            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
-                self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
-                              "https://github.com/pytorch/pytorch/issues/156513")
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-                self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
-
-        if torch.version.hip:
-            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
-                    and _get_torch_rocm_version() < (6, 4):
-                # NCHW bfloat16 path uses native kernels for rocm<=6.3
-                # train failed on rocm<=6.3 due to native tolerance issue
-                # https://github.com/pytorch/pytorch/issues/156513
-                self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
-
-            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16") \
-                    and _get_torch_rocm_version() >= (6, 4):
-                # https://github.com/pytorch/pytorch/issues/156513
-                self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
-
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
-                    and _get_torch_rocm_version() < (7, 0):
-                self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
-
-        if dims == 3 and memory_format in ("NHWC", "NCHW"):
-            memory_format = memory_format + "3D"
-
-        def _create_tensor(size, memory_format, dtype, device):
-            t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device)
-            t = t.random_(1, 10)
-            return t
-
-        def _get_ref_device(backend: str , device: str):
-            # If 'backend' specifies the memory format, return 'device' arg, otherwise return a device matches the backend
-            if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"):
-                return device
-            if backend == "native":
-                return "cuda"
-            if backend == "cpu":
-                return "cpu"
-            else:
-                raise ValueError("Unknown backend")
-
-        def _get_backend_memory_format(backend: str, memory_format: torch.memory_format) -> torch.memory_format:
-            # If 'backend' specifies the memory format, return it, otherwise look at 'memory_format' arg
-            if backend == "NHWC":
-                return torch.channels_last
-            if backend == "NHWC3D":
-                return torch.channels_last_3d
-            if backend in ("NCHW", "NCHW3D"):
-                return torch.contiguous_format
-            if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
-                return memory_format
-            raise ValueError("Unable to detect memory format for backend={backend} and memory_format={memory_format}")
-
-        def _get_memory_format(t: torch.Tensor) -> torch.memory_format:
-            if t.is_contiguous(memory_format=torch.contiguous_format):
-                return torch.contiguous_format
-            if t.is_contiguous(memory_format=torch.channels_last):
-                return torch.channels_last
-            if t.is_contiguous(memory_format=torch.channels_last_3d):
-                return torch.channels_last_3d
-            return ValueError("Unsupported memory_format")
-
-        def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format:
-            if memory_format_name == "NHWC":
-                return torch.channels_last
-            elif memory_format_name == "NHWC3D":
-                return torch.channels_last_3d
-            elif memory_format_name in ("NCHW", "NCHW3D"):
-                return torch.contiguous_format
-            return ValueError("Unsupported memory_format")
-
-        def _create_backend(inp: torch.Tensor, mixed: bool = False):
-            if inp.dim() == 4:
-                return nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
-            else:
-                return nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
-
-        def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
-            mod = _create_backend(inp, mixed).train()
-            mod.weight.data.uniform_()
-            mod.bias.data.uniform_()
-
-            ref_mod = _create_backend(ref_inp, mixed).train()
-            ref_mod.load_state_dict(mod.state_dict())
-
-            out = mod(inp)
-            out.backward(grad)
-
-            with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext():
-                ref_out = ref_mod(ref_inp)
-                ref_out.backward(ref_grad)
-
-            self.assertTrue(out.is_contiguous(memory_format=_get_memory_format(inp)))
-            self.assertTrue(ref_out.is_contiguous(memory_format=_get_memory_format(ref_inp)))
-            self.assertEqual(out, ref_out)
-            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
-            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
-            self.assertEqual(mod.running_mean, ref_mod.running_mean)
-            self.assertEqual(mod.running_var, ref_mod.running_var)
-            self.assertEqual(inp.grad, ref_inp.grad)
-
-        def _train(memory_format_name, ref_backend, mixed, dtype):
-            memory_format = _get_memory_format_from_name(memory_format_name)
-
-            ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
-            ref_device = _get_ref_device(ref_backend, device="cuda")
-
-            size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2)
-            inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_()
-            grad = _create_tensor(size, memory_format, dtype, device="cuda")
-            ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_()
-            ref_grad = grad.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
-
-            _test_batchnorm_train(inp=inp, grad=grad, mixed=mixed,
-                                  ref_inp=ref_inp, ref_grad=ref_grad, ref_backend=ref_backend)
-
-        def _inference(memory_format_name, ref_backend, mixed, dtype):
-            memory_format = _get_memory_format_from_name(memory_format_name)
-            ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
-            ref_device = _get_ref_device(ref_backend, device="cuda")
-
-            size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50)
-            inp = _create_tensor(size, memory_format, dtype, device="cuda")
-            ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
-            mod = _create_backend(inp, mixed).eval()
-            ref_mod = _create_backend(ref_inp, mixed).eval()
-
-            out = mod(inp)
-            with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext():
-                ref_out = ref_mod(ref_inp)
-            self.assertEqual(out, ref_out)
-
-        if mode == "train":
-            _train(memory_format, ref_backend, mixed, dtype)
-        else:
-            _inference(memory_format, ref_backend, mixed, dtype)
-
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_batchnorm_nhwc_cuda(self):
         for dtype in (torch.half, torch.float):
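
The block removed above is described by the commit message as duplicated test_batchnorm coverage; its core idea is to run batchnorm through the cuDNN/MIOpen backend and compare against a reference run with that backend disabled. A condensed, self-contained sketch of that train-mode comparison (not code from this PR, and with torch.testing.assert_close standing in for the test-class assertions) is:

# Condensed sketch of the NHWC-vs-native comparison the removed test performed.
import torch
import torch.nn as nn

def check_nhwc_batchnorm_against_native(dtype=torch.float):
    # Channels-last input and incoming gradient on the GPU.
    inp = torch.randn(4, 8, 2, 2, device="cuda", dtype=dtype)
    inp = inp.to(memory_format=torch.channels_last).requires_grad_()
    grad = torch.randn_like(inp)

    mod = nn.BatchNorm2d(8, device="cuda", dtype=dtype).train()
    ref = nn.BatchNorm2d(8, device="cuda", dtype=dtype).train()
    ref.load_state_dict(mod.state_dict())
    ref_inp = inp.detach().clone().requires_grad_()  # clone keeps channels_last

    # Forward/backward through the cuDNN/MIOpen path.
    out = mod(inp)
    out.backward(grad)

    # Reference: same computation with cuDNN/MIOpen disabled (native kernels).
    with torch.backends.cudnn.flags(enabled=False):
        ref_out = ref(ref_inp)
        ref_out.backward(grad)

    torch.testing.assert_close(out, ref_out)
    torch.testing.assert_close(inp.grad, ref_inp.grad)
    torch.testing.assert_close(mod.running_mean, ref.running_mean)
    torch.testing.assert_close(mod.running_var, ref.running_var)

if torch.cuda.is_available():
    check_nhwc_batchnorm_against_native()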
