     IS_PPC, \
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
     skipIfTorchDynamo, skipIfRocmVersionLessThan, gcIfJetson, set_default_dtype
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
     ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -5140,6 +5140,7 @@ def test_batchnorm_nhwc_cuda(self):
         self.assertEqual(out1, out2)

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D")
     @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x)
     @parametrize_test(
         # test verifies cudnn/miopen batchnorm with the reference backend or memory format
@@ -5155,14 +5156,11 @@ def test_batchnorm_nhwc_cuda(self):
         [
             ("NCHW", "cpu", False, torch.float),
             ("NCHW", "cpu", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NCHW", "cpu", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NCHW", "cpu", True, torch.bfloat16),

             ("NCHW", "native", False, torch.float),
             ("NCHW", "native", True, torch.half),
-            # this config failed for train and passed for inference on ROCm6.4
-            # subtest(("NCHW", "native", True, torch.bfloat16), decorators=[unittest.expectedFailure]),
+            ("NCHW", "native", True, torch.bfloat16),

             ("NHWC", "cpu", False, torch.float),
             ("NHWC", "cpu", True, torch.half),
@@ -5174,21 +5172,41 @@ def test_batchnorm_nhwc_cuda(self):

             ("NHWC", "NCHW", False, torch.float),
             ("NHWC", "NCHW", True, torch.half),
-            # NCHW bfloat16 path uses native kernels for rocm<=6.3
-            # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
-            subtest(("NHWC", "NCHW", True, torch.bfloat16), decorators=[skipIfRocmVersionLessThan((6, 4))]),
+            ("NHWC", "NCHW", True, torch.bfloat16),
         ],
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
-    def test_batchnorm(self, mode, memory_format, ref_backend, mixed, dtype):
+    def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
+        if torch.version.hip:
+            if self._testMethodName in ("test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() < (6, 4):
+                # NCHW bfloat16 path uses native kernels for rocm<=6.3
+                # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600
+                self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
+
+            if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16"
+                                        ) and _get_torch_rocm_version() >= (6, 4):
+                self.skipTest("bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600")
+
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
+                    and _get_torch_rocm_version() < (6, 4):
5197+ self.skipTest("3D float16 NCHW train failed on ROCm<=6.3 ")
+
+        if dims == 3 and memory_format in ("NHWC", "NCHW"):
+            memory_format = memory_format + "3D"
+
         def _create_tensor(size, memory_format, dtype, device):
             t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device)
             t = t.random_(1, 10)
             return t

         def _get_ref_device(backend: str, device: str):
             # If 'backend' specifies the memory format, return the 'device' arg; otherwise return a device that matches the backend
-            if backend in ("NHWC", "NCHW"):
+            if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"):
                 return device
             if backend == "native":
                 return "cuda"
@@ -5201,9 +5219,11 @@ def _get_backend_memory_format(backend: str, memory_format: torch.memory_format)
             # If 'backend' specifies the memory format, return it, otherwise look at 'memory_format' arg
             if backend == "NHWC":
                 return torch.channels_last
-            if backend == "NCHW":
+            if backend == "NHWC3D":
+                return torch.channels_last_3d
+            if backend in ("NCHW", "NCHW3D"):
                 return torch.contiguous_format
-            if memory_format in (torch.contiguous_format, torch.channels_last):
+            if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
                 return memory_format
             raise ValueError("Unable to detect memory format for backend={backend} and memory_format={memory_format}")

@@ -5212,10 +5232,24 @@ def _get_memory_format(t: torch.Tensor) -> torch.memory_format:
                 return torch.contiguous_format
             if t.is_contiguous(memory_format=torch.channels_last):
                 return torch.channels_last
+            if t.is_contiguous(memory_format=torch.channels_last_3d):
+                return torch.channels_last_3d
+            return ValueError("Unsupported memory_format")
+
+        def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format:
+            if memory_format_name == "NHWC":
+                return torch.channels_last
+            elif memory_format_name == "NHWC3D":
+                return torch.channels_last_3d
+            elif memory_format_name in ("NCHW", "NCHW3D"):
+                return torch.contiguous_format
             return ValueError("Unsupported memory_format")

         def _create_backend(inp: torch.Tensor, mixed: bool = False):
-            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
+
+            mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) \
+                if inp.dim() == 4 else \
+                nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
             return mod

         def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
@@ -5242,12 +5276,13 @@ def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
             self.assertEqual(mod.running_var, ref_mod.running_var)
             self.assertEqual(inp.grad, ref_inp.grad)

-        def _train(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _train(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
+
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")

-            size = (4, 8, 2, 2)
+            size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2)
             inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_()
             grad = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_()
@@ -5275,12 +5310,12 @@ def _train(memory_format, ref_backend, mixed, dtype):
             # _test_batchnorm_train(input=input, grad=grad, mixed=mixed,
             #     ref_input=ref_input, ref_grad=ref_grad, ref_backend=ref_backend)

-        def _inference(memory_format, ref_backend, mixed, dtype):
-            memory_format = torch.contiguous_format if memory_format == "NCHW" else torch.channels_last
+        def _inference(memory_format_name, ref_backend, mixed, dtype):
+            memory_format = _get_memory_format_from_name(memory_format_name)
             ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
             ref_device = _get_ref_device(ref_backend, device="cuda")

-            size = (2, 64, 50, 50)
+            size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50)
             inp = _create_tensor(size, memory_format, dtype, device="cuda")
             ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
             mod = _create_backend(inp, mixed).eval()