Skip to content

Commit c838f18

Browse files
committed
miopen tune api for batchnorm
1 parent 5fc1aea commit c838f18

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
115115
Constant zero(dataType, 0);
116116
Tensor save_mean, save_var;
117117

118+
auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
119+
miopenTuningPolicy_t previous_policy;
120+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
121+
if (tuning_policy != previous_policy) {
122+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
123+
}
124+
118125
if (training) {
119126
int64_t num_features = input_t.size(1);
120127
save_mean = at::empty({ num_features }, weight_t.options());
@@ -135,6 +142,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
135142
epsilon,
136143
save_mean.mutable_data_ptr(),
137144
save_var.mutable_data_ptr()));
145+
if (tuning_policy != previous_policy) {
146+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
147+
}
138148
} else {
139149
save_mean = at::empty({0}, weight_t.options());
140150
save_var = at::empty({0}, weight_t.options());
@@ -151,6 +161,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
151161
running_mean->data_ptr(),
152162
running_var->data_ptr(),
153163
epsilon));
164+
if (tuning_policy != previous_policy) {
165+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
166+
}
154167
}
155168

156169
// save_mean and save_var can be undefined
@@ -223,6 +236,12 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
223236
Constant one(dataType, 1);
224237
Constant zero(dataType, 0);
225238

239+
auto tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
240+
miopenTuningPolicy_t previous_policy;
241+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
242+
if (tuning_policy != previous_policy) {
243+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
244+
}
226245
MIOPEN_CHECK(miopenBatchNormalizationBackward(
227246
handle, mode, &one, &zero, &one, &zero,
228247
idesc.desc(), input->const_data_ptr(),
@@ -234,6 +253,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
234253
epsilon,
235254
save_mean->const_data_ptr(),
236255
save_var->const_data_ptr()));
256+
if (tuning_policy != previous_policy) {
257+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
258+
}
237259

238260
return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
239261
}

0 commit comments

Comments
 (0)