@@ -5145,175 +5145,6 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
         self.assertTrue(torch.equal(running_mean, bn.running_mean))
         self.assertTrue(torch.equal(running_var, bn.running_var))
 
-
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D")
-    @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x)
-    @parametrize_test(
-        # test verifies cudnn/miopen batchnorm with the reference backend or memory format
-        # memory_format - one of ("NCHW", "NHWC")
-        # ref_backend - one of ("cpu", "native", "NCHW", "NHWC")
-        #     "cpu"    - cpu backend with the same memory_format will be used as reference
-        #     "native" - native backend (`with torch.backends.cudnn.flags(enabled=False)`)
-        #                with the same memory_format will be used
-        #     "NCHW" or "NHWC" - the same backend will be used but another memory format
-        # mixed - True or False. Mixed batchnorm mode where inputs are 16-bit and batchnorm is fp32
-        #
5162- "memory_format,ref_backend,mixed,dtype",
5163- [
5164- ("NCHW", "cpu", False, torch.float),
5165- ("NCHW", "cpu", True, torch.half),
5166- ("NCHW", "cpu", True, torch.bfloat16),
5167-
5168- ("NCHW", "native", False, torch.float),
5169- ("NCHW", "native", True, torch.half),
5170- ("NCHW", "native", True, torch.bfloat16),
5171- ],
5172- name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
5173- )
5174- def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
5175- if torch.version.cuda:
5176- if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5177- "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
5178- self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
5179- "https://github.com/pytorch/pytorch/issues/156513")
5180- if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5181- self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
5182-
5183- if torch.version.hip:
5184- if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5185- "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
5186- and _get_torch_rocm_version() < (6, 4):
5187- # NCHW bfloat16 path uses native kernels for rocm<=6.3
5188- # train failed on rocm<=6.3 due to native tolerance issue
5189- # https://github.com/pytorch/pytorch/issues/156513
5190- self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
5191-
5192- if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16",
5193- "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16") \
5194- and _get_torch_rocm_version() >= (6, 4):
5195- # https://github.com/pytorch/pytorch/issues/156513
5196- self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
5197-
5198- if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
5199- and _get_torch_rocm_version() < (7, 0):
5200- self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
5201-
5202- if dims == 3 and memory_format in ("NHWC", "NCHW"):
5203- memory_format = memory_format + "3D"
5204-
5205- def _create_tensor(size, memory_format, dtype, device):
5206- t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device)
5207- t = t.random_(1, 10)
5208- return t
5209-
-        def _get_ref_device(backend: str, device: str):
-            # If 'backend' specifies the memory format, return the 'device' arg,
-            # otherwise return a device that matches the backend
-            if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"):
-                return device
-            if backend == "native":
-                return "cuda"
-            if backend == "cpu":
-                return "cpu"
-            else:
-                raise ValueError("Unknown backend")
-
-        def _get_backend_memory_format(backend: str, memory_format: torch.memory_format) -> torch.memory_format:
-            # If 'backend' specifies the memory format, return it, otherwise look at the 'memory_format' arg
-            if backend == "NHWC":
-                return torch.channels_last
-            if backend == "NHWC3D":
-                return torch.channels_last_3d
-            if backend in ("NCHW", "NCHW3D"):
-                return torch.contiguous_format
-            if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
-                return memory_format
-            raise ValueError(f"Unable to detect memory format for backend={backend} and memory_format={memory_format}")
-
-        def _get_memory_format(t: torch.Tensor) -> torch.memory_format:
-            if t.is_contiguous(memory_format=torch.contiguous_format):
-                return torch.contiguous_format
-            if t.is_contiguous(memory_format=torch.channels_last):
-                return torch.channels_last
-            if t.is_contiguous(memory_format=torch.channels_last_3d):
-                return torch.channels_last_3d
-            raise ValueError("Unsupported memory_format")
-
-        def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format:
-            if memory_format_name == "NHWC":
-                return torch.channels_last
-            elif memory_format_name == "NHWC3D":
-                return torch.channels_last_3d
-            elif memory_format_name in ("NCHW", "NCHW3D"):
-                return torch.contiguous_format
-            raise ValueError("Unsupported memory_format")
-
-        def _create_backend(inp: torch.Tensor, mixed: bool = False):
-            if inp.dim() == 4:
-                return nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
-            else:
-                return nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype)
-
-        def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend):
-            mod = _create_backend(inp, mixed).train()
-            mod.weight.data.uniform_()
-            mod.bias.data.uniform_()
-
-            ref_mod = _create_backend(ref_inp, mixed).train()
-            ref_mod.load_state_dict(mod.state_dict())
-
-            out = mod(inp)
-            out.backward(grad)
-
-            with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext():
-                ref_out = ref_mod(ref_inp)
-                ref_out.backward(ref_grad)
-
-            self.assertTrue(out.is_contiguous(memory_format=_get_memory_format(inp)))
-            self.assertTrue(ref_out.is_contiguous(memory_format=_get_memory_format(ref_inp)))
-            self.assertEqual(out, ref_out)
-            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
-            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
-            self.assertEqual(mod.running_mean, ref_mod.running_mean)
-            self.assertEqual(mod.running_var, ref_mod.running_var)
-            self.assertEqual(inp.grad, ref_inp.grad)
-
-        def _train(memory_format_name, ref_backend, mixed, dtype):
-            memory_format = _get_memory_format_from_name(memory_format_name)
-
-            ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
-            ref_device = _get_ref_device(ref_backend, device="cuda")
-
-            size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2)
-            inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_()
-            grad = _create_tensor(size, memory_format, dtype, device="cuda")
-            ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_()
-            ref_grad = grad.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
-
-            _test_batchnorm_train(inp=inp, grad=grad, mixed=mixed,
-                                  ref_inp=ref_inp, ref_grad=ref_grad, ref_backend=ref_backend)
-
-        def _inference(memory_format_name, ref_backend, mixed, dtype):
-            memory_format = _get_memory_format_from_name(memory_format_name)
-            ref_memory_format = _get_backend_memory_format(ref_backend, memory_format)
-            ref_device = _get_ref_device(ref_backend, device="cuda")
-
-            size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50)
-            inp = _create_tensor(size, memory_format, dtype, device="cuda")
-            ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device)
-            mod = _create_backend(inp, mixed).eval()
-            ref_mod = _create_backend(ref_inp, mixed).eval()
-
-            out = mod(inp)
-            with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext():
-                ref_out = ref_mod(ref_inp)
-            self.assertEqual(out, ref_out)
-
-        if mode == "train":
-            _train(memory_format, ref_backend, mixed, dtype)
-        else:
-            _inference(memory_format, ref_backend, mixed, dtype)
-
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_batchnorm_nhwc_cuda(self):
         for dtype in (torch.half, torch.float):