@@ -115,11 +115,18 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
   Constant zero(dataType, 0);
   Tensor save_mean, save_var;

+  auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
+  miopenTuningPolicy_t previous_policy;
+  MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
+  if (tuning_policy != previous_policy) {
+    MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
+  }
+
   if (training) {
     int64_t num_features = input_t.size(1);
     save_mean = at::empty({ num_features }, weight_t.options());
     save_var = at::empty({ num_features }, weight_t.options());
-    MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
+    auto status = miopenBatchNormalizationForwardTraining(
       handle, mode, &one, &zero,
       idesc.desc(), input->const_data_ptr(),
       idesc.desc(), output->data_ptr(),
@@ -134,11 +141,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
       at::maybe_data_ptr(running_var),
       epsilon,
       save_mean.mutable_data_ptr(),
-      save_var.mutable_data_ptr()));
+      save_var.mutable_data_ptr());
+    if (tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+    }
+    MIOPEN_CHECK(status);
   } else {
     save_mean = at::empty({0}, weight_t.options());
     save_var = at::empty({0}, weight_t.options());
-    MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
+    auto status = miopenBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->const_data_ptr(),
       idesc.desc(), output->data_ptr(),
@@ -150,7 +161,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
       const_cast<void*>(bias->const_data_ptr()),
       running_mean->data_ptr(),
       running_var->data_ptr(),
-      epsilon));
+      epsilon);
+    if (tuning_policy != previous_policy) {
+      MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+    }
+    MIOPEN_CHECK(status);
   }

   // save_mean and save_var can be undefined
@@ -223,7 +238,13 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
   Constant one(dataType, 1);
   Constant zero(dataType, 0);

-  MIOPEN_CHECK(miopenBatchNormalizationBackward(
+  auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
+  miopenTuningPolicy_t previous_policy;
+  MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
+  if (tuning_policy != previous_policy) {
+    MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
+  }
+  auto status = miopenBatchNormalizationBackward(
     handle, mode, &one, &zero, &one, &zero,
     idesc.desc(), input->const_data_ptr(),
     idesc.desc(), grad_output->const_data_ptr(),
@@ -233,7 +254,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     grad_bias_t.data_ptr(),
     epsilon,
     save_mean->const_data_ptr(),
-    save_var->const_data_ptr()));
+    save_var->const_data_ptr());
+  if (tuning_policy != previous_policy) {
+    MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
+  }
+  MIOPEN_CHECK(status);

   return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
 }
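
The same get/set/call/restore/check sequence now appears at three call sites. A minimal sketch of how it could be factored into an RAII guard follows; it assumes only the miopenGetTuningPolicy/miopenSetTuningPolicy entry points and the MIOPEN_CHECK macro already used in this diff, and the TuningPolicyGuard name is hypothetical rather than part of the patch.

struct TuningPolicyGuard {
  // Save the current tuning policy and switch to the desired one if they differ.
  TuningPolicyGuard(miopenHandle_t handle, miopenTuningPolicy_t desired)
      : handle_(handle) {
    MIOPEN_CHECK(miopenGetTuningPolicy(handle_, &previous_));
    if (desired != previous_) {
      MIOPEN_CHECK(miopenSetTuningPolicy(handle_, desired));
      changed_ = true;
    }
  }
  // Restore the previous policy on scope exit, including the exception path;
  // the status is intentionally ignored so the destructor never throws.
  ~TuningPolicyGuard() {
    if (changed_) {
      (void)miopenSetTuningPolicy(handle_, previous_);
    }
  }
  miopenHandle_t handle_;
  miopenTuningPolicy_t previous_{};
  bool changed_ = false;
};

With such a guard, each call site would construct the guard just before the existing MIOPEN_CHECK-wrapped MIOpen call, and the previous policy would be restored automatically when the guard goes out of scope.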