backends/mlu/kernels/adam_kernel.cc: 30 additions & 19 deletions
@@ -115,9 +115,7 @@ void AdamKernel(const Context& dev_ctx,
                         "value is:%d.",
                         beta2_pow_out->numel()));
 
-  const phi::DenseTensor* beta1_tensor = nullptr;
-  const phi::DenseTensor* beta2_tensor = nullptr;
-  const phi::DenseTensor* epsilon_tensor = nullptr;
+  Tensor beta1_tensor;
 
   phi::DenseTensor beta1_tmp;
   phi::DenseTensor beta2_tmp;
@@ -128,19 +126,30 @@ void AdamKernel(const Context& dev_ctx,
   epsilon_tmp.Resize({1});
 
   MPDType beta1 = beta1_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&beta1_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, beta1, &beta1_tmp);
-  beta1_tensor = &beta1_tmp;
 
   MPDType beta2 = beta2_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&beta2_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, beta2, &beta2_tmp);
-  beta2_tensor = &beta2_tmp;
 
   MPDType epsilon = epsilon_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&epsilon_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, epsilon, &epsilon_tmp);
-  epsilon_tensor = &epsilon_tmp;
+
+  std::vector<MPDType> parameter_list;
+  parameter_list.push_back(beta1);
+  parameter_list.push_back(beta2);
+  parameter_list.push_back(epsilon);
+
+  Tensor dst;
+  dst.Resize({3});
+  auto dst_place = phi::CustomPlace();
+  C_Device_st device{dst_place.GetDeviceId()};
+  void* dst_ptr = dev_ctx.template Alloc<MPDType>(&dst);
+  auto src_ptr = static_cast<void*>(parameter_list.data());
+  MemCpyH2D(&device, dst_ptr, src_ptr, parameter_list.size() * sizeof(MPDType));
+
+  const void* beta1_tensor_ptr = nullptr;
+  const void* beta2_tensor_ptr = nullptr;
+  const void* epsilon_tensor_ptr = nullptr;
+  beta1_tensor_ptr = dst_ptr;
+  beta2_tensor_ptr = static_cast<char*>(dst_ptr) + sizeof(MPDType);
+  epsilon_tensor_ptr = static_cast<char*>(dst_ptr) + 2 * sizeof(MPDType);
 
   Tensor t_param_in_out, t_grad;
   t_param_in_out.Resize(param.dims());
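
This hunk is the heart of the change: rather than allocating beta1_tmp, beta2_tmp, and epsilon_tmp and filling each with its own FillMLUTensorWithHostValue call, the kernel now packs all three scalars into one host vector, performs a single MemCpyH2D into a 3-element device buffer, and addresses each scalar by a byte offset. A minimal host-only sketch of that packing pattern, in which std::memcpy and a heap buffer stand in for MemCpyH2D and the MLU allocation (stand-ins, not the MLU API):

    #include <cstring>
    #include <vector>

    int main() {
      using MPDType = float;  // compute type used by the kernel for fp32 params

      // Pack beta1, beta2, epsilon contiguously on the host, as the kernel does.
      std::vector<MPDType> parameter_list = {0.9f, 0.999f, 1e-8f};

      // Stand-in for dev_ctx.template Alloc<MPDType>(&dst) on the device.
      std::vector<char> device_buf(parameter_list.size() * sizeof(MPDType));

      // One copy replaces three per-scalar fills (MemCpyH2D in the real kernel).
      std::memcpy(device_buf.data(), parameter_list.data(),
                  parameter_list.size() * sizeof(MPDType));

      // Each scalar is then addressed by a byte offset into the same buffer.
      const void* beta1_ptr = device_buf.data();
      const void* beta2_ptr = device_buf.data() + sizeof(MPDType);
      const void* epsilon_ptr = device_buf.data() + 2 * sizeof(MPDType);

      (void)beta1_ptr;
      (void)beta2_ptr;
      (void)epsilon_ptr;
      return 0;
    }
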
@@ -198,11 +207,11 @@ void AdamKernel(const Context& dev_ctx,
                      grad_desc.get(),
                      GetBasePtr(&t_grad),
                      GetBasePtr(&learning_rate),
-                     GetBasePtr(beta1_tensor),
-                     GetBasePtr(beta2_tensor),
+                     beta1_tensor_ptr,
+                     beta2_tensor_ptr,
                      GetBasePtr(beta1_pow),
                      GetBasePtr(beta2_pow),
-                     GetBasePtr(epsilon_tensor),
+                     epsilon_tensor_ptr,
                      /*use_nesterov*/ false);
 
   if (param.dtype() != phi::DataType::FLOAT32) {
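
The fused MLUCnnl::ApplyAdam call now reads beta1, beta2, and epsilon straight from the packed buffer instead of from three standalone tensors; everything else (moments, learning rate, beta powers) is unchanged. For orientation, a host-side sketch of the textbook Adam step such a fused kernel performs; the exact CNNL semantics, e.g. whether bias correction consumes beta1_pow/beta2_pow inside the kernel, are an assumption here, not something the diff shows:

    #include <cmath>

    // Reference Adam step over n elements (textbook form, not the CNNL kernel).
    void adam_step_ref(float* param, float* m, float* v, const float* grad,
                       int n, float lr, float beta1, float beta2, float epsilon,
                       float beta1_pow, float beta2_pow) {
      for (int i = 0; i < n; ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];            // 1st moment
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];  // 2nd moment
        const float m_hat = m[i] / (1.0f - beta1_pow);  // bias-corrected
        const float v_hat = v[i] / (1.0f - beta2_pow);
        param[i] -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
      }
    }
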
@@ -221,7 +230,6 @@ void AdamKernel(const Context& dev_ctx,
                   param_out_desc.get(),
                   GetBasePtr(param_out));
   }
-
   if (!use_global_beta_pow) {
     if (beta1_pow->place().GetType() == phi::AllocationType::CPU &&
         beta2_pow->place().GetType() == phi::AllocationType::CPU) {
@@ -235,7 +243,10 @@
       dev_ctx.template Alloc<MPDType>(beta1_pow_out);
       dev_ctx.template Alloc<MPDType>(beta2_pow_out);
 
-      MLUCnnlTensorDesc beta1_desc(*beta1_tensor);
+      beta1_tensor.Resize({1});
+      MLUCnnlTensorDesc beta1_desc(
+          beta1_tensor, CNNL_LAYOUT_ARRAY, ToCnnlDataType<MPDType>());
+
       MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL,
                                       ToCnnlDataType<MPDType>(),
                                       CNNL_NOT_PROPAGATE_NAN);
@@ -245,7 +256,7 @@
                         beta1_desc.get(),
                         GetBasePtr(beta1_pow),
                         beta1_desc.get(),
-                        GetBasePtr(beta1_tensor),
+                        beta1_tensor_ptr,
                         beta1_desc.get(),
                         GetBasePtr(beta1_pow_out),
                         ToCnnlDataType<MPDType>());
@@ -255,7 +266,7 @@
                         beta1_desc.get(),
                         GetBasePtr(beta2_pow),
                         beta1_desc.get(),
-                        GetBasePtr(beta2_tensor),
+                        beta2_tensor_ptr,
                         beta1_desc.get(),
                         GetBasePtr(beta2_pow_out),
                         ToCnnlDataType<MPDType>());
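
Note that beta1_tensor is never filled with data: after Resize({1}) it only supplies shape and dtype metadata for the shared descriptor, while both OpTensor calls take their scalar operands from the packed device buffer through beta1_tensor_ptr and beta2_tensor_ptr. What the two CNNL_OP_TENSOR_MUL calls compute, written out on the host for reference (scalar floats stand in for the {1}-shaped device tensors):

    #include <cstdio>

    int main() {
      const float beta1 = 0.9f, beta2 = 0.999f;
      float beta1_pow = 0.9f, beta2_pow = 0.999f;  // beta^t before this step

      // Advance the running powers used by Adam's bias correction.
      const float beta1_pow_out = beta1_pow * beta1;  // beta1^(t+1)
      const float beta2_pow_out = beta2_pow * beta2;  // beta2^(t+1)

      std::printf("beta1_pow_out=%f beta2_pow_out=%f\n",
                  beta1_pow_out, beta2_pow_out);
      return 0;
    }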