Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Tensor expandScale(const Tensor& t, int64_t dim) {

} // namespace

// True when the MIOpen tuning-policy API can be used: requires a build compiled
// with MIOpen AND a sufficiently new runtime version. 30501 presumably encodes
// version 3.5.1 (major*10000 + minor*100 + patch) — TODO confirm against the
// versionMIOpen() encoding. The && short-circuits so versionMIOpen() is only
// queried when MIOpen support was compiled in (replaces the equivalent
// `cond ? expr : false` ternary).
static const bool miopen_tuning_available =
    detail::getCUDAHooks().compiledWithMIOpen() &&
    detail::getCUDAHooks().versionMIOpen() >= 30501;

std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt, const std::optional<Tensor>& running_mean_t_opt, const std::optional<Tensor>& running_var_t_opt,
bool training, double exponential_average_factor, double epsilon)
Expand Down Expand Up @@ -115,6 +117,16 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
Constant zero(dataType, 0);
Tensor save_mean, save_var;

miopenTuningPolicy_t tuning_policy;
miopenTuningPolicy_t previous_policy;
if (miopen_tuning_available) {
tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
if (tuning_policy != previous_policy) {
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
}
}

if (training) {
int64_t num_features = input_t.size(1);
save_mean = at::empty({ num_features }, weight_t.options());
Expand All @@ -135,6 +147,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
epsilon,
save_mean.mutable_data_ptr(),
save_var.mutable_data_ptr()));
if (miopen_tuning_available && tuning_policy != previous_policy) {
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
}
} else {
save_mean = at::empty({0}, weight_t.options());
save_var = at::empty({0}, weight_t.options());
Expand All @@ -151,6 +166,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
running_mean->data_ptr(),
running_var->data_ptr(),
epsilon));
if (miopen_tuning_available && tuning_policy != previous_policy) {
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
}
}

// save_mean and save_var can be undefined
Expand Down Expand Up @@ -223,6 +241,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
Constant one(dataType, 1);
Constant zero(dataType, 0);

miopenTuningPolicy_t tuning_policy;
miopenTuningPolicy_t previous_policy;
if (miopen_tuning_available) {
tuning_policy = at::globalContext().benchmarkCuDNN() ? miopenTuningPolicySearch : miopenTuningPolicyNone;
MIOPEN_CHECK(miopenGetTuningPolicy(handle, &previous_policy));
if (tuning_policy != previous_policy) {
MIOPEN_CHECK(miopenSetTuningPolicy(handle, tuning_policy));
}
}
MIOPEN_CHECK(miopenBatchNormalizationBackward(
handle, mode, &one, &zero, &one, &zero,
idesc.desc(), input->const_data_ptr(),
Expand All @@ -234,6 +261,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
epsilon,
save_mean->const_data_ptr(),
save_var->const_data_ptr()));
if (miopen_tuning_available && tuning_policy != previous_policy) {
MIOPEN_CHECK(miopenSetTuningPolicy(handle, previous_policy));
}

return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
}
Expand Down