
Commit e4d62b1

okakarpa, dnikolaev-amd, and jeffdaily authored and committed

[AUTOGENERATED] [release/2.7] Add 3D batchnorm tests (#2243)
Cherry-pick of #2214

Co-authored-by: Dmitry Nikolaev <[email protected]>
Co-authored-by: Jeff Daily <[email protected]>
(cherry picked from commit 5631e07)
1 parent ae17c3a commit e4d62b1
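
The change extends the existing test_batchnorm parametrization with a dims axis (2D vs. 3D) and moves the ROCm-specific subtest decorators into runtime skips keyed on the generated test name. A rough, illustrative sketch of how such a name is assembled from the three name_fn callables in the diff below (dtype_name here is an assumed helper, mirroring the suffixes the generated names suggest):

# Illustration only: assembling one of the test names matched by the new ROCm skip logic.
import torch

def dtype_name(t):  # assumed helper; strips the "torch." prefix
    return str(t).replace("torch.", "")

dims_suffix = (lambda x: f"{x}D")(3)                     # "3D"
mode_suffix = (lambda x: x)("train")                     # "train"
config_suffix = (lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}")(
    "NHWC", "NCHW", True, torch.bfloat16)                # "NHWC_vs_NCHW_mixed_bfloat16"
print(f"test_batchnorm_{dims_suffix}_{mode_suffix}_{config_suffix}")
# -> test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16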

File tree

1 file changed: +55 -20 lines changed


test/test_nn.py

Lines changed: 55 additions & 20 deletions
@@ -37,7 +37,7 @@
     IS_PPC, \
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
     skipIfTorchDynamo, skipIfRocmVersionLessThan, gcIfJetson, set_default_dtype
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
     ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -5140,6 +5140,7 @@ def test_batchnorm_nhwc_cuda(self):
         self.assertEqual(out1, out2)

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D")
     @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x)
     @parametrize_test(
         # test verifies cudnn/miopen batchnorm with the reference backend or memory format
@@ -5155,14 +5156,11 @@ def test_batchnorm_nhwc_cuda(self):
         [
             ("NCHW", "cpu", False, torch.float),
             ("NCHW", "cpu", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NCHW", "cpu", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NCHW", "cpu", True, torch.bfloat16),

             ("NCHW", "native", False, torch.float),
             ("NCHW", "native", True, torch.half),
-            # this config failed for train and passed for inference on ROCm6.4
-            # subtest(("NCHW", "native", True, torch.bfloat16), decorators=[unittest.expectedFailure]),
+            ("NCHW", "native", True, torch.bfloat16),

             ("NHWC", "cpu", False, torch.float),
             ("NHWC", "cpu", True, torch.half),
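
The removed skipIfRocmVersionLessThan / expectedFailure decorators are replaced by runtime skips inside the test body (see the next hunk). A minimal sketch of that version gate, assuming a ROCm build of PyTorch (on CUDA builds torch.version.hip is None and the branch is never taken):

# Minimal sketch of the gating pattern added in the next hunk; _get_torch_rocm_version()
# comes from torch.testing._internal.common_cuda (imported above) and returns a tuple
# comparable against (6, 4).
import torch
from torch.testing._internal.common_cuda import _get_torch_rocm_version

if torch.version.hip and _get_torch_rocm_version() < (6, 4):
    print("ROCm <= 6.3: the bfloat16 mixed-precision train configs are skipped")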
@@ -5174,21 +5172,41 @@ def test_batchnorm_nhwc_cuda(self):

             ("NHWC", "NCHW", False, torch.float),
             ("NHWC", "NCHW", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NHWC", "NCHW", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NHWC", "NCHW", True, torch.bfloat16),
         ],
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
-    def test_batchnorm(self, mode, memory_format, ref_backend, mixed, dtype):
+    def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
+        if torch.version.hip:
+            if self._testMethodName in ("test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() < (6, 4):
+                # NCHW bfloat16 path uses native kernels for rocm<=6.3
+                # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
+                self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
+
+            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() >= (6, 4):
+                self.skipTest("bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600")
+
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
+                    and _get_torch_rocm_version() < (6, 4):
+                self.skipTest("3D float16 NCHW train failed on ROCm<=6.3 ")
+
+        if dims == 3 and memory_format in ("NHWC", "NCHW"):
+            memory_format = memory_format + "3D"
+
         def _create_tensor(size, memory_format, dtype, device):
             t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device)
             t = t.random_(1, 10)
             return t

         def _get_ref_device(backend: str , device: str):
             # If 'backend' specifies the memory format, return 'device' arg, otherwise return a device matches the backend
-            if backend in ("NHWC", "NCHW"):
+            if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"):
                 return device
             if backend == "native":
                 return "cuda"
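
With dims == 3 the format names become NHWC3D / NCHW3D, which the helpers below map to torch.channels_last_3d and torch.contiguous_format for 5D (N, C, D, H, W) tensors. A quick standalone sketch of that layout (CPU is enough to show the strides; the tests themselves run on CUDA):

# channels_last_3d stores the data physically as N, D, H, W, C while the logical
# shape stays (N, C, D, H, W), so the channel dimension gets stride 1.
import torch

x = torch.empty(4, 8, 2, 2, 2, memory_format=torch.channels_last_3d)
print(x.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(x.stride())  # (64, 1, 32, 16, 8)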
@@ -5201,9 +5219,11 @@ def _get_backend_memory_format(backend: str, memory_format: torch.memory_format)
             # If 'backend' specifies the memory format, return it, otherwise look at 'memory_format' arg
             if backend == "NHWC":
                 return torch.channels_last
-            if backend == "NCHW":
+            if backend == "NHWC3D":
+                return torch.channels_last_3d
+            if backend in ("NCHW", "NCHW3D"):
                 return torch.contiguous_format
-            if memory_format in (torch.contiguous_format, torch.channels_last):
+            if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
                 return memory_format
             raise ValueError("Unable to detect memory format for backend={backend} and memory_format={memory_format}")

@@ -5212,10 +5232,24 @@ def _get_memory_format(t: torch.Tensor) -> torch.memory_format:
                 return torch.contiguous_format
             if t.is_contiguous(memory_format=torch.channels_last):
                 return torch.channels_last
+            if t.is_contiguous(memory_format=torch.channels_last_3d):
+                return torch.channels_last_3d
+            return ValueError("Unsupported memory_format")
+
+        def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format:
+            if memory_format_name == "NHWC":
+                return torch.channels_last
+            elif memory_format_name == "NHWC3D":
+                return torch.channels_last_3d
+            elif memory_format_name in ("NCHW", "NCHW3D"):
+                return torch.contiguous_format
             return ValueError("Unsupported memory_format")

         def _create_backend(inp: torch.Tensor, mixed: bool = False):
-            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
+
+            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) \
+                if inp.dim() == 4 else \
+                nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
             return mod

         def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
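
_create_backend now dispatches on the input rank: nn.BatchNorm2d for 4D inputs, nn.BatchNorm3d for 5D. A small self-contained sketch of that dispatch (sizes chosen arbitrarily for illustration):

# BatchNorm3d normalizes per channel over (N, D, H, W); it is selected here by the
# same inp.dim() == 4 check that _create_backend uses.
import torch
import torch.nn as nn

inp = torch.randn(4, 8, 2, 2, 2)  # (N, C, D, H, W) -> inp.dim() == 5
mod = nn.BatchNorm2d(inp.size(1)) if inp.dim() == 4 else nn.BatchNorm3d(inp.size(1))
print(type(mod).__name__, tuple(mod(inp).shape))  # BatchNorm3d (4, 8, 2, 2, 2)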
@@ -5242,12 +5276,13 @@ def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
             self.assertEqual(mod.running_var, ref_mod.running_var)
             self.assertEqual(inp.grad, ref_inp.grad)

-        def _train(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _train(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
+
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")

-            size = (4, 8, 2, 2)
+            size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2)
             inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_()
             grad = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_()
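
For the 3D configs _train builds a 5D channels_last_3d input and compares against a reference copy holding the same values in the reference backend's layout and device. A condensed sketch of that setup, simplified to CPU and a contiguous (NCHW3D-style) reference:

# Same values, different physical layout: the reference copy is contiguous while the
# test input is channels_last_3d, mirroring the _train setup for a 3D config.
import torch

size = (4, 8, 2, 2, 2)
inp = torch.empty(size, memory_format=torch.channels_last_3d).random_(1, 10).requires_grad_()
ref_inp = inp.detach().clone(memory_format=torch.contiguous_format).requires_grad_()

assert torch.equal(inp, ref_inp)        # identical values
print(inp.stride(), ref_inp.stride())   # (64, 1, 32, 16, 8) vs (64, 8, 4, 2, 1)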
@@ -5275,12 +5310,12 @@ def _train(memory_format, ref_backend, mixed, dtype):
             # _test_batchnorm_train(input=input, grad=grad, mixed=mixed,
             #                       ref_input=ref_input, ref_grad=ref_grad, ref_backend=ref_backend)

-        def _inference(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _inference(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")

-            size = (2, 64, 50, 50)
+            size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50)
             inp = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
             mod = _create_backend(inp, mixed).eval()
