@@ -5174,25 +5174,22 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
     def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
-<<<<<<< HEAD
         if torch.version.cuda:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
-                self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
-                              "https://github.com/pytorch/pytorch/issues/156513")
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-                self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16"):
+                self.skipTest("train failed on CUDA due to native batchnorm accuracy issues")
 
-=======
-        if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-            self.skipTest("3D float16 NCHW train failed on CUDA and ROCm due to Native batchnorm accuracy issue SWDEV-541024")
->>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
         if torch.version.hip:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
                     and _get_torch_rocm_version() < (6, 4):
                 # NCHW bfloat16 path uses native kernels for rocm<=6.3
-                # train failed on rocm<=6.3 due to native tolerance issue
+                # train failed on rocm<=6.3 due to native accuracy issue
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
 
@@ -5202,13 +5199,9 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
 
-<<<<<<< HEAD
-        if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
-            and _get_torch_rocm_version() < (7, 0):
-            self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
+        if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
+            self.skipTest("3D float16 NCHW train failed on CUDA and ROCm due to native batchnorm accuracy issue SWDEV-541024")
 
-=======
->>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
         if dims == 3 and memory_format in ("NHWC", "NCHW"):
             memory_format = memory_format + "3D"
 
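Note on the pattern above: both hunks use the same skip idiom, gating first on the accelerator build (torch.version.cuda / torch.version.hip) and then, on ROCm, comparing a parsed version tuple against a cutoff such as (6, 4). A minimal, self-contained sketch of that idiom follows; the _rocm_version helper and the test class are illustrative stand-ins, not part of this diff, written on the assumption that _get_torch_rocm_version behaves like the parser shown.

import unittest

import torch


def _rocm_version():
    # Illustrative stand-in for the _get_torch_rocm_version() helper used in
    # the diff: parse torch.version.hip (e.g. "6.3.42131-...") into a
    # comparable (major, minor) tuple; return (0, 0) on non-ROCm builds.
    if torch.version.hip is None:
        return (0, 0)
    return tuple(int(p) for p in torch.version.hip.split(".")[:2])


class ExampleSkips(unittest.TestCase):
    # Hypothetical test class demonstrating the version-gated skip only.
    def test_example(self):
        if torch.version.hip and _rocm_version() < (6, 4):
            # Tuple comparison is lexicographic: (6, 3) < (6, 4) < (7, 0),
            # so the skip fires only on ROCm builds older than 6.4.
            self.skipTest("known native-kernel accuracy issue on ROCm < 6.4")
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()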