@@ -58,6 +58,8 @@ Tensor expandScale(const Tensor& t, int64_t dim) {
 
 } // namespace
 
+static const bool miopen_tuning_available = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() >= 30501 : false;
+
 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt, const std::optional<Tensor>& running_mean_t_opt, const std::optional<Tensor>& running_var_t_opt,
     bool training, double exponential_average_factor, double epsilon)
@@ -115,6 +117,16 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
   Constant zero(dataType, 0);
   Tensor save_mean, save_var;
 
+  miopenTuningPolicy_t tuning_policy;
+  miopenTuningPolicy_t previous_policy;
+  if (miopen_tuning_available) {
+    tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
+    MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
+    if (tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
+    }
+  }
+
   if (training) {
     int64_t num_features = input_t.size(1);
     save_mean = at::empty({ num_features }, weight_t.options());
@@ -135,6 +147,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
         epsilon,
         save_mean.mutable_data_ptr(),
         save_var.mutable_data_ptr()));
+    if (miopen_tuning_available && tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+    }
   } else {
     save_mean = at::empty({0}, weight_t.options());
     save_var = at::empty({0}, weight_t.options());
@@ -151,6 +166,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
         running_mean->data_ptr(),
         running_var->data_ptr(),
         epsilon));
+    if (miopen_tuning_available && tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+    }
   }
 
   // save_mean and save_var can be undefined
@@ -223,6 +241,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
 
+  miopenTuningPolicy_t tuning_policy;
+  miopenTuningPolicy_t previous_policy;
+  if (miopen_tuning_available) {
+    tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
+    MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
+    if (tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
+    }
+  }
   MIOPEN_CHECK(miopenBatchNormalizationBackward(
       handle, mode, &one, &zero, &one, &zero,
       idesc.desc(), input->const_data_ptr(),
@@ -234,6 +261,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
       epsilon,
      save_mean->const_data_ptr(),
      save_var->const_data_ptr()));
+  if (miopen_tuning_available && tuning_policy != previous_policy) {
+    MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+  }
 
   return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
 }