Skip to content

Commit 674b218

Browse files
committed
miopen tune api for batchnorm
1 parent 5fc1aea commit 674b218

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,18 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
115115
Constant zero(dataType, 0);
116116
Tensor save_mean, save_var;
117117

118+
auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
119+
miopenTuningPolicy_t previous_policy;
120+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
121+
if (tuning_policy != previous_policy) {
122+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
123+
}
124+
118125
if (training) {
119126
int64_t num_features = input_t.size(1);
120127
save_mean = at::empty({ num_features }, weight_t.options());
121128
save_var = at::empty({ num_features }, weight_t.options());
122-
MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
129+
auto status = miopenBatchNormalizationForwardTraining(
123130
handle, mode, &one, &zero,
124131
idesc.desc(), input->const_data_ptr(),
125132
idesc.desc(), output->data_ptr(),
@@ -134,11 +141,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
134141
at::maybe_data_ptr(running_var),
135142
epsilon,
136143
save_mean.mutable_data_ptr(),
137-
save_var.mutable_data_ptr()));
144+
save_var.mutable_data_ptr());
145+
if (tuning_policy != previous_policy) {
146+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
147+
}
148+
MIOPEN_CHECK(status);
138149
} else {
139150
save_mean = at::empty({0}, weight_t.options());
140151
save_var = at::empty({0}, weight_t.options());
141-
MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
152+
auto status = miopenBatchNormalizationForwardInference(
142153
handle, mode, &one, &zero,
143154
idesc.desc(), input->const_data_ptr(),
144155
idesc.desc(), output->data_ptr(),
@@ -150,7 +161,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
150161
const_cast<void*>(bias->const_data_ptr()),
151162
running_mean->data_ptr(),
152163
running_var->data_ptr(),
153-
epsilon));
164+
epsilon);
165+
if (tuning_policy != previous_policy) {
166+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
167+
}
168+
MIOPEN_CHECK(status);
154169
}
155170

156171
// save_mean and save_var can be undefined
@@ -223,7 +238,13 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
223238
Constant one(dataType, 1);
224239
Constant zero(dataType, 0);
225240

226-
MIOPEN_CHECK(miopenBatchNormalizationBackward(
241+
auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
242+
miopenTuningPolicy_t previous_policy;
243+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
244+
if (tuning_policy != previous_policy) {
245+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
246+
}
247+
auto status = miopenBatchNormalizationBackward(
227248
handle, mode, &one, &zero, &one, &zero,
228249
idesc.desc(), input->const_data_ptr(),
229250
idesc.desc(), grad_output->const_data_ptr(),
@@ -233,7 +254,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
233254
grad_bias_t.data_ptr(),
234255
epsilon,
235256
save_mean->const_data_ptr(),
236-
save_var->const_data_ptr()));
257+
save_var->const_data_ptr());
258+
if (tuning_policy != previous_policy) {
259+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
260+
}
261+
MIOPEN_CHECK(status);
237262

238263
return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
239264
}

0 commit comments

Comments
 (0)