
Commit bbf4b9b

[release/2.6] Add 3D batchnorm tests (#2214)
Additive on top of #2209: adds 3D batchnorm tests (NHWC3D and NCHW3D).

NCHW 3D tests:

```
test_batchnorm_3D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.149s)
test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.062s)
test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.042s)
test_batchnorm_3D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.091s)
test_batchnorm_3D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.008s)
test_batchnorm_3D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s)
test_batchnorm_3D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.028s)
test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.010s)
test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.010s)
test_batchnorm_3D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.091s)
test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.020s)
test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.023s)
test_batchnorm_3D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.010s)
test_batchnorm_3D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.015s)
test_batchnorm_3D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s)
test_batchnorm_3D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s)
test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_3D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_3D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s)
test_batchnorm_3D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_3D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.011s)
test_batchnorm_3D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_3D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s)
```

The old batchnorm tests now have `2D` in their names:

```
test_batchnorm_2D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.023s)
test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.005s)
test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.005s)
test_batchnorm_2D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.104s)
test_batchnorm_2D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.020s)
test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s)
test_batchnorm_2D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s)
test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_2D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_2D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s)
test_batchnorm_2D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s)
test_batchnorm_2D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s)
test_batchnorm_2D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s)
```

Tested in `compute-rocm-dkms-no-npi-hipclang` image build 16062:
`compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:16062_ubuntu22.04_py3.10_pytorch_lw_release-2.7_1fee1967`

Tests can be run with the environment variable `MIOPEN_ENABLE_LOGGING_CMD=1` to collect MIOpenDriver commands:

```
MIOPEN_ENABLE_LOGGING_CMD=1 python test_nn.py -v -k test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16
test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NDHWC
MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NDHWC
MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NCDHW
MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NCDHW
ok
```

Co-authored-by: Jeff Daily <[email protected]>
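Editor's note: as an illustrative aid only (not part of this commit), the sketch below shows what the NHWC3D/NDHWC path exercises: a 5D tensor in `torch.channels_last_3d` layout fed through `nn.BatchNorm3d`, using the same N=4, C=8, D=2, H=2, W=2 shape that appears as `-n 4 -c 8 -D 2 -H 2 -W 2` in the MIOpenDriver commands above. It assumes a CUDA/ROCm build of PyTorch with a GPU available.

```
# Illustrative sketch only; assumes a CUDA/ROCm device is available.
import torch
import torch.nn as nn

# Same shape the new tests use: N=4, C=8, D=2, H=2, W=2.
x = torch.randn(4, 8, 2, 2, 2, device="cuda").to(memory_format=torch.channels_last_3d)
bn = nn.BatchNorm3d(8, device="cuda")

out = bn(x)
# Check whether the NDHWC (channels_last_3d) layout survived the batchnorm call.
print(out.is_contiguous(memory_format=torch.channels_last_3d))
```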
1 parent 2045a75 commit bbf4b9b

File tree

1 file changed: +55 −20 lines


test/test_nn.py

Lines changed: 55 additions & 20 deletions
@@ -36,7 +36,7 @@
     IS_PPC, \
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
     skipIfTorchDynamo, skipIfRocmVersionLessThan, gcIfJetson, set_default_dtype
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
     ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -5089,6 +5089,7 @@ def test_batchnorm_nhwc_cuda(self):
         self.assertEqual(out1, out2)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D")
     @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x)
     @parametrize_test(
         # test verifies cudnn/miopen batchnorm with the reference backend or memory format
@@ -5104,14 +5105,11 @@ def test_batchnorm_nhwc_cuda(self):
         [
             ("NCHW", "cpu", False, torch.float),
             ("NCHW", "cpu", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NCHW", "cpu", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NCHW", "cpu", True, torch.bfloat16),
 
             ("NCHW", "native", False, torch.float),
             ("NCHW", "native", True, torch.half),
-            # this config failed for train and passed for inference on ROCm6.4
-            # subtest(("NCHW", "native", True, torch.bfloat16), decorators=[unittest.expectedFailure]),
+            ("NCHW", "native", True, torch.bfloat16),
 
             ("NHWC", "cpu", False, torch.float),
             ("NHWC", "cpu", True, torch.half),
@@ -5123,21 +5121,41 @@ def test_batchnorm_nhwc_cuda(self):
 
             ("NHWC", "NCHW", False, torch.float),
             ("NHWC", "NCHW", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NHWC", "NCHW", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NHWC", "NCHW", True, torch.bfloat16),
         ],
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
-    def test_batchnorm(self, mode, memory_format, ref_backend, mixed, dtype):
+    def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
+        if torch.version.hip:
+            if self._testMethodName in ("test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() < (6, 4):
+                # NCHW bfloat16 path uses native kernels for rocm<=6.3
+                # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
+                self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
+
+            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() >= (6, 4):
+                self.skipTest("bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600")
+
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
+                    and _get_torch_rocm_version() < (6, 4):
+                self.skipTest("3D float16 NCHW train failed on ROCm<=6.3")
+
+        if dims == 3 and memory_format in ("NHWC", "NCHW"):
+            memory_format = memory_format + "3D"
+
         def _create_tensor(size, memory_format, dtype, device):
             t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device)
             t = t.random_(1, 10)
             return t
 
         def _get_ref_device(backend: str , device: str):
             # If 'backend' specifies the memory format, return 'device' arg, otherwise return a device matches the backend
-            if backend in ("NHWC", "NCHW"):
+            if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"):
                 return device
             if backend == "native":
                 return "cuda"
@@ -5150,9 +5168,11 @@ def _get_backend_memory_format(backend: str, memory_format: torch.memory_format)
             # If 'backend' specifies the memory format, return it, otherwise look at 'memory_format' arg
             if backend == "NHWC":
                 return torch.channels_last
-            if backend == "NCHW":
+            if backend == "NHWC3D":
+                return torch.channels_last_3d
+            if backend in ("NCHW", "NCHW3D"):
                 return torch.contiguous_format
-            if memory_format in (torch.contiguous_format, torch.channels_last):
+            if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
                 return memory_format
             raise ValueError("Unable to detect memory format for backend={backend} and memory_format={memory_format}")
 
@@ -5161,10 +5181,24 @@ def _get_memory_format(t: torch.Tensor) -> torch.memory_format:
                 return torch.contiguous_format
             if t.is_contiguous(memory_format=torch.channels_last):
                 return torch.channels_last
+            if t.is_contiguous(memory_format=torch.channels_last_3d):
+                return torch.channels_last_3d
+            return ValueError("Unsupported memory_format")
+
+        def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format:
+            if memory_format_name == "NHWC":
+                return torch.channels_last
+            elif memory_format_name == "NHWC3D":
+                return torch.channels_last_3d
+            elif memory_format_name in ("NCHW", "NCHW3D"):
+                return torch.contiguous_format
             return ValueError("Unsupported memory_format")
 
         def _create_backend(inp: torch.Tensor, mixed: bool = False):
-            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
+
+            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) \
+                if inp.dim() == 4 else \
+                nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
             return mod
 
         def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
@@ -5191,12 +5225,13 @@ def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
             self.assertEqual(mod.running_var, ref_mod.running_var)
             self.assertEqual(inp.grad, ref_inp.grad)
 
-        def _train(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _train(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
+
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")
 
-            size = (4, 8, 2, 2)
+            size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2)
             inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_()
             grad = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_()
@@ -5224,12 +5259,12 @@ def _train(memory_format, ref_backend, mixed, dtype):
             # _test_batchnorm_train(input=input, grad=grad, mixed=mixed,
             #                       ref_input=ref_input, ref_grad=ref_grad, ref_backend=ref_backend)
 
-        def _inference(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _inference(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")
 
-            size = (2, 64, 50, 50)
+            size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50)
             inp = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
             mod = _create_backend(inp, mixed).eval()
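Editor's note: as a reading aid for the parametrization added above, here is a hedged, standalone sketch of how the stacked `name_fn`s compose into the test names listed in the commit message. The `dtype_name` helper below is a stand-in, not the exact implementation in test_nn.py, and the composition is shown by hand rather than via `parametrize_test` itself.

```
# Standalone sketch (assumed helper names) of the generated test-name pattern,
# e.g. test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16.
import torch

def dtype_name(t: torch.dtype) -> str:
    # Stand-in helper: "torch.bfloat16" -> "bfloat16"
    return str(t).split(".")[-1]

dims_fn = lambda d: f"{d}D"          # name_fn for the "dims" parameter
mode_fn = lambda m: m                # name_fn for the "mode" parameter
cfg_fn = lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"

# One of the configs from the diff: NHWC compared against NCHW, mixed, bfloat16.
d, mode, (f, b, m, t) = 3, "train", ("NHWC", "NCHW", True, torch.bfloat16)
print(f"test_batchnorm_{dims_fn(d)}_{mode_fn(mode)}_{cfg_fn(f, b, m, t)}")
# -> test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16
```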
