@@ -115,6 +115,13 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
115115 Constant zero (dataType, 0 );
116116 Tensor save_mean, save_var;
117117
118+ auto tuning_policy = at::globalContext ().benchmarkCuDNN () ? miopenTuningPolicySearch : miopenTuningPolicyNone;
119+ miopenTuningPolicy_t previous_policy;
120+ MIOPEN_CHECK (miopenGetTuningPolicy (handle, &previous_policy));
121+ if (tuning_policy != previous_policy) {
122+ MIOPEN_CHECK (miopenSetTuningPolicy (handle, tuning_policy));
123+ }
124+
118125 if (training) {
119126 int64_t num_features = input_t .size (1 );
120127 save_mean = at::empty ({ num_features }, weight_t .options ());
@@ -135,6 +142,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
135142 epsilon,
136143 save_mean.mutable_data_ptr (),
137144 save_var.mutable_data_ptr ()));
145+ if (tuning_policy != previous_policy) {
146+ MIOPEN_CHECK (miopenSetTuningPolicy (handle, previous_policy));
147+ }
138148 } else {
139149 save_mean = at::empty ({0 }, weight_t .options ());
140150 save_var = at::empty ({0 }, weight_t .options ());
@@ -151,6 +161,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
151161 running_mean->data_ptr (),
152162 running_var->data_ptr (),
153163 epsilon));
164+ if (tuning_policy != previous_policy) {
165+ MIOPEN_CHECK (miopenSetTuningPolicy (handle, previous_policy));
166+ }
154167 }
155168
156169 // save_mean and save_var can be undefined
@@ -223,6 +236,12 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
223236 Constant one (dataType, 1 );
224237 Constant zero (dataType, 0 );
225238
239+ auto tuning_policy = at::globalContext ().benchmarkCuDNN () ? miopenTuningPolicySearch : miopenTuningPolicyNone;
240+ miopenTuningPolicy_t previous_policy;
241+ MIOPEN_CHECK (miopenGetTuningPolicy (handle, &previous_policy));
242+ if (tuning_policy != previous_policy) {
243+ MIOPEN_CHECK (miopenSetTuningPolicy (handle, tuning_policy));
244+ }
226245 MIOPEN_CHECK (miopenBatchNormalizationBackward (
227246 handle, mode, &one, &zero, &one, &zero,
228247 idesc.desc (), input->const_data_ptr (),
@@ -234,6 +253,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
234253 epsilon,
235254 save_mean->const_data_ptr (),
236255 save_var->const_data_ptr ()));
256+ if (tuning_policy != previous_policy) {
257+ MIOPEN_CHECK (miopenSetTuningPolicy (handle, previous_policy));
258+ }
237259
238260 return std::tuple<Tensor,Tensor,Tensor>{grad_input_t , grad_weight_t , grad_bias_t };
239261}
0 commit comments