Skip to content

Commit e0afc3a

Browse files
[release/2.7] enable NHWC batchnorm by default on ROCm7.0+ (#2180)
NHWC batchnorm enabled by default if ROCm>=7.0
1 parent 92d6dd8 commit e0afc3a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

aten/src/ATen/native/Normalization.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ BatchNormBackend _select_batch_norm_backend(
524524
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen
525525
// See #64427
526526
// non static variable is used to be able to change environment variable in runtime for testing
527-
bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(false);
527+
// enabled by default for ROCm >= 7.0.0
528+
bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(ROCM_VERSION >= 70000);
528529

529530
if (
530531
input.is_cuda()

0 commit comments

Comments
 (0)