@@ -233,29 +233,38 @@ void CastFP64toFP32Kernel(const Context& dev_ctx,
233
233
}
234
234
235
235
template <typename T, typename Context>
236
- void AdamKernel (const Context& dev_ctx,
237
- const phi::DenseTensor& param,
238
- const phi::DenseTensor& grad,
239
- const phi::DenseTensor& learning_rate,
240
- const phi::DenseTensor& moment1,
241
- const phi::DenseTensor& moment2,
242
- const phi::DenseTensor& beta1_pow_in,
243
- const phi::DenseTensor& beta2_pow_in,
244
- const paddle::optional<phi::DenseTensor>& master_param,
245
- const paddle::optional<phi::DenseTensor>& skip_update,
246
- const phi::Scalar& beta1_in,
247
- const phi::Scalar& beta2_in,
248
- const phi::Scalar& epsilon_in,
249
- bool lazy_mode,
250
- int64_t min_row_size_to_use_multithread,
251
- bool multi_precision,
252
- bool use_global_beta_pow,
253
- phi::DenseTensor* param_out,
254
- phi::DenseTensor* moment1_out,
255
- phi::DenseTensor* moment2_out,
256
- phi::DenseTensor* beta1_pow_out,
257
- phi::DenseTensor* beta2_pow_out,
258
- phi::DenseTensor* master_param_out) {
236
+ void AdamKernel (
237
+ const Context& dev_ctx,
238
+ const phi::DenseTensor& param,
239
+ const phi::DenseTensor& grad,
240
+ const phi::DenseTensor& learning_rate,
241
+ const phi::DenseTensor& moment1,
242
+ const phi::DenseTensor& moment2,
243
+ const paddle::optional<phi::DenseTensor>& moment2_max, // UNUSED
244
+ const phi::DenseTensor& beta1_pow_in,
245
+ const phi::DenseTensor& beta2_pow_in,
246
+ const paddle::optional<phi::DenseTensor>& master_param,
247
+ const paddle::optional<phi::DenseTensor>& skip_update,
248
+ const phi::Scalar& beta1_in,
249
+ const phi::Scalar& beta2_in,
250
+ const phi::Scalar& epsilon_in,
251
+ bool lazy_mode,
252
+ int64_t min_row_size_to_use_multithread,
253
+ bool multi_precision,
254
+ bool use_global_beta_pow,
255
+ bool amsgrad, // UNUSED
256
+ phi::DenseTensor* param_out,
257
+ phi::DenseTensor* moment1_out,
258
+ phi::DenseTensor* moment2_out,
259
+ phi::DenseTensor* moment2_max_out, // UNUSED
260
+ phi::DenseTensor* beta1_pow_out,
261
+ phi::DenseTensor* beta2_pow_out,
262
+ phi::DenseTensor* master_param_out) {
263
+ PADDLE_ENFORCE_NE (
264
+ amsgrad,
265
+ true ,
266
+ phi::errors::Unimplemented (" Operation amsgrad is not supported yet." ));
267
+
259
268
bool skip_update_ = false ;
260
269
if (skip_update.is_initialized ()) {
261
270
PADDLE_ENFORCE_EQ (skip_update->numel (),
@@ -358,32 +367,41 @@ void AdamKernel(const Context& dev_ctx,
358
367
}
359
368
360
369
template <typename T, typename Context>
361
- void AdamwKernel (const Context& dev_ctx,
362
- const phi::DenseTensor& param,
363
- const phi::DenseTensor& grad,
364
- const phi::DenseTensor& learning_rate,
365
- const phi::DenseTensor& moment1,
366
- const phi::DenseTensor& moment2,
367
- const phi::DenseTensor& beta1_pow,
368
- const phi::DenseTensor& beta2_pow,
369
- const paddle::optional<phi::DenseTensor>& master_param,
370
- const paddle::optional<phi::DenseTensor>& skip_update,
371
- const phi::Scalar& beta1,
372
- const phi::Scalar& beta2,
373
- const phi::Scalar& epsilon,
374
- float lr_ratio,
375
- float coeff,
376
- bool with_decay,
377
- bool lazy_mode,
378
- int64_t min_row_size_to_use_multithread,
379
- bool multi_precision,
380
- bool use_global_beta_pow,
381
- phi::DenseTensor* param_out,
382
- phi::DenseTensor* moment1_out,
383
- phi::DenseTensor* moment2_out,
384
- phi::DenseTensor* beta1_pow_out,
385
- phi::DenseTensor* beta2_pow_out,
386
- phi::DenseTensor* master_param_outs) {
370
+ void AdamwKernel (
371
+ const Context& dev_ctx,
372
+ const phi::DenseTensor& param,
373
+ const phi::DenseTensor& grad,
374
+ const phi::DenseTensor& learning_rate,
375
+ const phi::DenseTensor& moment1,
376
+ const phi::DenseTensor& moment2,
377
+ const paddle::optional<phi::DenseTensor>& moment2_max, // UNUSED
378
+ const phi::DenseTensor& beta1_pow,
379
+ const phi::DenseTensor& beta2_pow,
380
+ const paddle::optional<phi::DenseTensor>& master_param,
381
+ const paddle::optional<phi::DenseTensor>& skip_update,
382
+ const phi::Scalar& beta1,
383
+ const phi::Scalar& beta2,
384
+ const phi::Scalar& epsilon,
385
+ float lr_ratio,
386
+ float coeff,
387
+ bool with_decay,
388
+ bool lazy_mode,
389
+ int64_t min_row_size_to_use_multithread,
390
+ bool multi_precision,
391
+ bool use_global_beta_pow,
392
+ bool amsgrad, // UNUSED
393
+ phi::DenseTensor* param_out,
394
+ phi::DenseTensor* moment1_out,
395
+ phi::DenseTensor* moment2_out,
396
+ phi::DenseTensor* moment2_max_out, // UNUSED
397
+ phi::DenseTensor* beta1_pow_out,
398
+ phi::DenseTensor* beta2_pow_out,
399
+ phi::DenseTensor* master_param_outs) {
400
+ PADDLE_ENFORCE_NE (
401
+ amsgrad,
402
+ true ,
403
+ phi::errors::Unimplemented (" Operation amsgrad is not supported yet." ));
404
+
387
405
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
388
406
389
407
bool skip_update_ = false ;
@@ -514,18 +532,19 @@ PD_REGISTER_PLUGIN_KERNEL(adam,
514
532
float ,
515
533
double ) {
516
534
// Skip beta1_pow, beta2_pow, skip_update data transform
517
- kernel->InputAt (5 ).SetBackend (phi::Backend::ALL_BACKEND);
518
535
kernel->InputAt (6 ).SetBackend (phi::Backend::ALL_BACKEND);
519
- kernel->InputAt (8 ).SetBackend (phi::Backend::ALL_BACKEND);
536
+ kernel->InputAt (7 ).SetBackend (phi::Backend::ALL_BACKEND);
537
+ kernel->InputAt (9 ).SetBackend (phi::Backend::ALL_BACKEND);
520
538
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
521
539
kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
522
540
kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
523
541
kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
524
542
kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
525
543
kernel->OutputAt (5 ).SetDataType (phi::DataType::FLOAT32);
544
+ kernel->OutputAt (6 ).SetDataType (phi::DataType::FLOAT32);
526
545
}
527
- kernel->OutputAt (3 ).SetBackend (phi::Backend::UNDEFINED);
528
546
kernel->OutputAt (4 ).SetBackend (phi::Backend::UNDEFINED);
547
+ kernel->OutputAt (5 ).SetBackend (phi::Backend::UNDEFINED);
529
548
}
530
549
531
550
PD_REGISTER_PLUGIN_KERNEL (adamw,
@@ -537,16 +556,17 @@ PD_REGISTER_PLUGIN_KERNEL(adamw,
537
556
float ,
538
557
double ) {
539
558
// Skip beta1_pow, beta2_pow, skip_update data transform
540
- kernel->InputAt (5 ).SetBackend (phi::Backend::ALL_BACKEND);
541
559
kernel->InputAt (6 ).SetBackend (phi::Backend::ALL_BACKEND);
542
- kernel->InputAt (8 ).SetBackend (phi::Backend::ALL_BACKEND);
560
+ kernel->InputAt (7 ).SetBackend (phi::Backend::ALL_BACKEND);
561
+ kernel->InputAt (9 ).SetBackend (phi::Backend::ALL_BACKEND);
543
562
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
544
563
kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
545
564
kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
546
565
kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
547
566
kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
548
567
kernel->OutputAt (5 ).SetDataType (phi::DataType::FLOAT32);
568
+ kernel->OutputAt (6 ).SetDataType (phi::DataType::FLOAT32);
549
569
}
550
- kernel->OutputAt (3 ).SetBackend (phi::Backend::UNDEFINED);
551
570
kernel->OutputAt (4 ).SetBackend (phi::Backend::UNDEFINED);
571
+ kernel->OutputAt (5 ).SetBackend (phi::Backend::UNDEFINED);
552
572
}
0 commit comments