Skip to content

Commit cc48438

Browse files
Update test_nn.py
resolve conflict
1 parent 2e50ad3 commit cc48438

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

test/test_nn.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5174,25 +5174,22 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
51745174
name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
51755175
)
51765176
def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
5177-
<<<<<<< HEAD
51785177
if torch.version.cuda:
51795178
if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5180-
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
5181-
self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
5182-
"https://github.com/pytorch/pytorch/issues/156513")
5183-
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5184-
self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
5179+
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
5180+
"test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
5181+
"test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
5182+
"test_batchnorm_3D_train_NCHW_vs_native_mixed_float16"):
5183+
self.skipTest("Failed on CUDA")
51855184

5186-
=======
5187-
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5188-
self.skipTest("3D float16 NCHW train failed on CUDA and ROCm due to Native batchnorm accuracy issue SWDEV-541024")
5189-
>>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
51905185
if torch.version.hip:
51915186
if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
5192-
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
5187+
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
5188+
"test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
5189+
"test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
51935190
and _get_torch_rocm_version() < (6, 4):
51945191
# NCHW bfloat16 path uses native kernels for rocm<=6.3
5195-
# train failed on rocm<=6.3 due to native tolerance issue
5192+
# train failed on rocm<=6.3 due to native accuracy issue
51965193
# https://github.com/pytorch/pytorch/issues/156513
51975194
self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
51985195

@@ -5202,13 +5199,9 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
52025199
# https://github.com/pytorch/pytorch/issues/156513
52035200
self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
52045201

5205-
<<<<<<< HEAD
5206-
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
5207-
and _get_torch_rocm_version() < (7, 0):
5208-
self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
5202+
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5203+
self.skipTest("3D float16 NCHW train failed on ROCm")
52095204

5210-
=======
5211-
>>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
52125205
if dims == 3 and memory_format in ("NHWC", "NCHW"):
52135206
memory_format = memory_format + "3D"
52145207

0 commit comments

Comments
 (0)