Commit c8372d5 (1 parent: 9817aad)

[Hackathon 7th PPSCI No.12] Add amsgrad support to the Adam and AdamW optimizers (#1484)
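
For context on the feature being wired through here: AMSGrad extends Adam by keeping an element-wise running maximum of the second-moment estimate and dividing by that maximum instead of the current estimate, so the effective step size cannot grow back after the second moment shrinks. A minimal per-element sketch of the rule, with bias correction omitted and all names illustrative rather than taken from the Paddle sources:

#include <algorithm>
#include <cmath>

// Hypothetical scalar AMSGrad step (bias correction omitted for brevity).
// m, v, v_max are the first moment, second moment, and its running maximum.
float amsgrad_step(float param, float grad, float& m, float& v, float& v_max,
                   float lr, float beta1, float beta2, float eps) {
  m = beta1 * m + (1.0f - beta1) * grad;         // first-moment update
  v = beta2 * v + (1.0f - beta2) * grad * grad;  // second-moment update
  v_max = std::max(v_max, v);                    // extra AMSGrad state
  return param - lr * m / (std::sqrt(v_max) + eps);
}

This running maximum is the extra state that appears below as the new moment2_max input and moment2_max_out output.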

File tree

5 files changed: +237, -91 lines


backends/npu/kernels/adam_kernel.cc

Lines changed: 75 additions & 55 deletions
@@ -233,29 +233,38 @@ void CastFP64toFP32Kernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void AdamKernel(const Context& dev_ctx,
-                const phi::DenseTensor& param,
-                const phi::DenseTensor& grad,
-                const phi::DenseTensor& learning_rate,
-                const phi::DenseTensor& moment1,
-                const phi::DenseTensor& moment2,
-                const phi::DenseTensor& beta1_pow_in,
-                const phi::DenseTensor& beta2_pow_in,
-                const paddle::optional<phi::DenseTensor>& master_param,
-                const paddle::optional<phi::DenseTensor>& skip_update,
-                const phi::Scalar& beta1_in,
-                const phi::Scalar& beta2_in,
-                const phi::Scalar& epsilon_in,
-                bool lazy_mode,
-                int64_t min_row_size_to_use_multithread,
-                bool multi_precision,
-                bool use_global_beta_pow,
-                phi::DenseTensor* param_out,
-                phi::DenseTensor* moment1_out,
-                phi::DenseTensor* moment2_out,
-                phi::DenseTensor* beta1_pow_out,
-                phi::DenseTensor* beta2_pow_out,
-                phi::DenseTensor* master_param_out) {
+void AdamKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& param,
+    const phi::DenseTensor& grad,
+    const phi::DenseTensor& learning_rate,
+    const phi::DenseTensor& moment1,
+    const phi::DenseTensor& moment2,
+    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
+    const phi::DenseTensor& beta1_pow_in,
+    const phi::DenseTensor& beta2_pow_in,
+    const paddle::optional<phi::DenseTensor>& master_param,
+    const paddle::optional<phi::DenseTensor>& skip_update,
+    const phi::Scalar& beta1_in,
+    const phi::Scalar& beta2_in,
+    const phi::Scalar& epsilon_in,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    bool amsgrad,  // UNUSED
+    phi::DenseTensor* param_out,
+    phi::DenseTensor* moment1_out,
+    phi::DenseTensor* moment2_out,
+    phi::DenseTensor* moment2_max_out,  // UNUSED
+    phi::DenseTensor* beta1_pow_out,
+    phi::DenseTensor* beta2_pow_out,
+    phi::DenseTensor* master_param_out) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   bool skip_update_ = false;
   if (skip_update.is_initialized()) {
     PADDLE_ENFORCE_EQ(skip_update->numel(),
@@ -358,32 +367,41 @@ void AdamKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void AdamwKernel(const Context& dev_ctx,
-                 const phi::DenseTensor& param,
-                 const phi::DenseTensor& grad,
-                 const phi::DenseTensor& learning_rate,
-                 const phi::DenseTensor& moment1,
-                 const phi::DenseTensor& moment2,
-                 const phi::DenseTensor& beta1_pow,
-                 const phi::DenseTensor& beta2_pow,
-                 const paddle::optional<phi::DenseTensor>& master_param,
-                 const paddle::optional<phi::DenseTensor>& skip_update,
-                 const phi::Scalar& beta1,
-                 const phi::Scalar& beta2,
-                 const phi::Scalar& epsilon,
-                 float lr_ratio,
-                 float coeff,
-                 bool with_decay,
-                 bool lazy_mode,
-                 int64_t min_row_size_to_use_multithread,
-                 bool multi_precision,
-                 bool use_global_beta_pow,
-                 phi::DenseTensor* param_out,
-                 phi::DenseTensor* moment1_out,
-                 phi::DenseTensor* moment2_out,
-                 phi::DenseTensor* beta1_pow_out,
-                 phi::DenseTensor* beta2_pow_out,
-                 phi::DenseTensor* master_param_outs) {
+void AdamwKernel(
+    const Context& dev_ctx,
+    const phi::DenseTensor& param,
+    const phi::DenseTensor& grad,
+    const phi::DenseTensor& learning_rate,
+    const phi::DenseTensor& moment1,
+    const phi::DenseTensor& moment2,
+    const paddle::optional<phi::DenseTensor>& moment2_max,  // UNUSED
+    const phi::DenseTensor& beta1_pow,
+    const phi::DenseTensor& beta2_pow,
+    const paddle::optional<phi::DenseTensor>& master_param,
+    const paddle::optional<phi::DenseTensor>& skip_update,
+    const phi::Scalar& beta1,
+    const phi::Scalar& beta2,
+    const phi::Scalar& epsilon,
+    float lr_ratio,
+    float coeff,
+    bool with_decay,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    bool amsgrad,  // UNUSED
+    phi::DenseTensor* param_out,
+    phi::DenseTensor* moment1_out,
+    phi::DenseTensor* moment2_out,
+    phi::DenseTensor* moment2_max_out,  // UNUSED
+    phi::DenseTensor* beta1_pow_out,
+    phi::DenseTensor* beta2_pow_out,
+    phi::DenseTensor* master_param_outs) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
 
   bool skip_update_ = false;
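
Both NPU kernels accept the new moment2_max input, the amsgrad flag, and the moment2_max_out output only so that their signatures match the updated kernel declaration; the arguments are marked UNUSED and amsgrad == true is rejected with an Unimplemented error because this backend does not maintain the running maximum yet. If that support were added later, the extra state would presumably be refreshed with an element-wise maximum along these lines (a hypothetical helper, not code from this commit):

#include <algorithm>
#include <cstdint>

// Hypothetical helper: refresh the AMSGrad state after moment2 is updated.
// mom2_out holds the freshly updated second moment; mom2_max_in/out hold the
// previous and new running maximum, all buffers with numel elements.
void UpdateMoment2Max(const float* mom2_max_in, const float* mom2_out,
                      float* mom2_max_out, int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    mom2_max_out[i] = std::max(mom2_max_in[i], mom2_out[i]);
  }
}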
@@ -514,18 +532,19 @@ PD_REGISTER_PLUGIN_KERNEL(adam,
                           float,
                           double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
-  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
   }
-  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
   kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
 }
 
 PD_REGISTER_PLUGIN_KERNEL(adamw,
@@ -537,16 +556,17 @@ PD_REGISTER_PLUGIN_KERNEL(adamw,
                           float,
                           double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
-  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
   }
-  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
   kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
 }
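
The index changes in both PD_REGISTER_PLUGIN_KERNEL blocks follow directly from the new signatures: moment2_max becomes tensor input 5, pushing beta1_pow, beta2_pow, and skip_update from indices 5/6/8 to 6/7/9, while moment2_max_out becomes output 3, pushing beta1_pow_out and beta2_pow_out from 3/4 to 4/5 and master_param_out to 6; hence the added OutputAt(6) FP32 setting and the OutputAt(5) UNDEFINED backend. The full ordering implied by the updated signatures (assuming, as the old indices already suggest, that InputAt/OutputAt count the dense-tensor arguments in declaration order):

// inputs : 0 param, 1 grad, 2 learning_rate, 3 moment1, 4 moment2,
//          5 moment2_max, 6 beta1_pow, 7 beta2_pow, 8 master_param,
//          9 skip_update
// outputs: 0 param_out, 1 moment1_out, 2 moment2_out, 3 moment2_max_out,
//          4 beta1_pow_out, 5 beta2_pow_out, 6 master_param_out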
