@@ -5176,18 +5176,20 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
51765176 def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
51775177 if torch.version.cuda:
51785178 if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5179- "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
5180- self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
5181- "https://github.com/pytorch/pytorch/issues/156513")
5182- if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5183- self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
5179+ "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
5180+ "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
5181+ "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
5182+ "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16") :
5183+ self.skipTest("Failed on CUDA")
51845184
51855185 if torch.version.hip:
51865186 if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5187- "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
5187+ "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
5188+ "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
5189+ "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
51885190 and _get_torch_rocm_version() < (6, 4):
51895191 # NCHW bfloat16 path uses native kernels for rocm<=6.3
5190- # train failed on rocm<=6.3 due to native tolerance issue
5192+ # train failed on rocm<=6.3 due to native accuracy issue
51915193 # https://github.com/pytorch/pytorch/issues/156513
51925194 self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
51935195
@@ -5197,9 +5199,8 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
51975199 # https://github.com/pytorch/pytorch/issues/156513
51985200 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
51995201
5200- if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
5201- and _get_torch_rocm_version() < (7, 0):
5202- self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
5202+ if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5203+ self.skipTest("3D float16 NCHW train failed on ROCm")
52035204
52045205 if dims == 3 and memory_format in ("NHWC", "NCHW"):
52055206 memory_format = memory_format + "3D"
0 commit comments