diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 0c122c9e13d4d..866baa4a43459 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -58,6 +58,8 @@ Tensor expandScale(const Tensor& t, int64_t dim) { } // namespace +static const bool miopen_tuning_available = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() >= 30501 : false; + std::tuple miopen_batch_norm( const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, bool training, double exponential_average_factor, double epsilon) @@ -115,6 +117,16 @@ std::tuple miopen_batch_norm( Constant zero(dataType, 0); Tensor save_mean, save_var; + miopenTuningPolicy_t tuning_policy; + miopenTuningPolicy_t previous_policy; + if (miopen_tuning_available) { + tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone; + MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy)); + if (tuning_policy != previous_policy) { + MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy)); + } + } + if (training) { int64_t num_features = input_t.size(1); save_mean = at::empty({ num_features }, weight_t.options()); @@ -135,6 +147,9 @@ std::tuple miopen_batch_norm( epsilon, save_mean.mutable_data_ptr(), save_var.mutable_data_ptr())); + if (miopen_tuning_available && tuning_policy != previous_policy) { + MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy)); + } } else { save_mean = at::empty({0}, weight_t.options()); save_var = at::empty({0}, weight_t.options()); @@ -151,6 +166,9 @@ std::tuple miopen_batch_norm( running_mean->data_ptr(), running_var->data_ptr(), epsilon)); + if (miopen_tuning_available && tuning_policy != previous_policy) { + MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy)); + } } // save_mean and save_var can be undefined @@ -223,6 +241,15 @@ std::tuple miopen_batch_norm_backward( Constant one(dataType, 1); Constant zero(dataType, 0); + miopenTuningPolicy_t tuning_policy; + miopenTuningPolicy_t previous_policy; + if (miopen_tuning_available) { + tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone; + MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy)); + if (tuning_policy != previous_policy) { + MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy)); + } + } MIOPEN_CHECK(miopenBatchNormalizationBackward( handle, mode, &one, &zero, &one, &zero, idesc.desc(), input->const_data_ptr(), @@ -234,6 +261,9 @@ std::tuple miopen_batch_norm_backward( epsilon, save_mean->const_data_ptr(), save_var->const_data_ptr())); + if (miopen_tuning_available && tuning_policy != previous_policy) { + MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy)); + } return std::tuple{grad_input_t, grad_weight_t, grad_bias_t}; }