Skip to content

Commit 2e50ad3

Browse files
dnikolaev-amdAMD AMD
authored andcommitted
Cherry-picked commit with merge conflict
1 parent e7df144 commit 2e50ad3

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

test/test_nn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5174,6 +5174,7 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
51745174
name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
51755175
)
51765176
def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
5177+
<<<<<<< HEAD
51775178
if torch.version.cuda:
51785179
if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
51795180
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
@@ -5182,6 +5183,10 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
51825183
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
51835184
self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
51845185

5186+
=======
5187+
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
5188+
self.skipTest("3D float16 NCHW train failed on CUDA and ROCm due to Native batchnorm accuracy issue SWDEV-541024")
5189+
>>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
51855190
if torch.version.hip:
51865191
if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
51875192
"test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
@@ -5197,10 +5202,13 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
51975202
# https://github.com/pytorch/pytorch/issues/156513
51985203
self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
51995204

5205+
<<<<<<< HEAD
52005206
if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
52015207
and _get_torch_rocm_version() < (7, 0):
52025208
self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
52035209

5210+
=======
5211+
>>>>>>> 4eaa5bf23b ([rocm7.0_internal_testing] skip 3D NCHW FP16 batchnorm test due to Native accuracy issue (#2370))
52045212
if dims == 3 and memory_format in ("NHWC", "NCHW"):
52055213
memory_format = memory_format + "3D"
52065214

0 commit comments

Comments
 (0)