Skip to content

Commit 6c6aa7a

Browse files
committed
MIOpen tuning API for batchnorm
1 parent 5fc1aea commit 6c6aa7a

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ Tensor expandScale(const Tensor& t, int64_t dim) {
5858

5959
} // namespace
6060

61+
static const bool miopen_tuning_available = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() >= 30501 : false;
62+
6163
std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
6264
const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt, const std::optional<Tensor>& running_mean_t_opt, const std::optional<Tensor>& running_var_t_opt,
6365
bool training, double exponential_average_factor, double epsilon)
@@ -115,6 +117,16 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
115117
Constant zero(dataType, 0);
116118
Tensor save_mean, save_var;
117119

120+
miopenTuningPolicy_t tuning_policy;
121+
miopenTuningPolicy_t previous_policy;
122+
if (miopen_tuning_available) {
123+
tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
124+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
125+
if (tuning_policy != previous_policy) {
126+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
127+
}
128+
}
129+
118130
if (training) {
119131
int64_t num_features = input_t.size(1);
120132
save_mean = at::empty({ num_features }, weight_t.options());
@@ -135,6 +147,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
135147
epsilon,
136148
save_mean.mutable_data_ptr(),
137149
save_var.mutable_data_ptr()));
150+
if (miopen_tuning_available && tuning_policy != previous_policy) {
151+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
152+
}
138153
} else {
139154
save_mean = at::empty({0}, weight_t.options());
140155
save_var = at::empty({0}, weight_t.options());
@@ -151,6 +166,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
151166
running_mean->data_ptr(),
152167
running_var->data_ptr(),
153168
epsilon));
169+
if (miopen_tuning_available && tuning_policy != previous_policy) {
170+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
171+
}
154172
}
155173

156174
// save_mean and save_var can be undefined
@@ -223,6 +241,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
223241
Constant one(dataType, 1);
224242
Constant zero(dataType, 0);
225243

244+
miopenTuningPolicy_t tuning_policy;
245+
miopenTuningPolicy_t previous_policy;
246+
if (miopen_tuning_available) {
247+
tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
248+
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
249+
if (tuning_policy != previous_policy) {
250+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
251+
}
252+
}
226253
MIOPEN_CHECK(miopenBatchNormalizationBackward(
227254
handle, mode, &one, &zero, &one, &zero,
228255
idesc.desc(), input->const_data_ptr(),
@@ -234,6 +261,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
234261
epsilon,
235262
save_mean->const_data_ptr(),
236263
save_var->const_data_ptr()));
264+
if (miopen_tuning_available && tuning_policy != previous_policy) {
265+
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
266+
}
237267

238268
return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
239269
}

0 commit comments

Comments (0)