@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <Eigen/Dense>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -311,17 +312,17 @@ struct SparseAdamFunctor<T, CPUAdam> {
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-    size_t row_count = numel / row_numel_;
+    int64_t row_count = static_cast<int64_t>(numel / row_numel_);
 
-    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+    for (int64_t i = 0, j = 0; i != row_count; ++i) {
       if (i == *(rows_ + j)) {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T g = grad_[j * row_numel_ + k];
           adam_update(i * row_numel_ + k, g);
         }
         ++j;
       } else {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T mom1 = moment1_[i * row_numel_ + k];
           T mom2 = moment2_[i * row_numel_ + k];
           T p = param_[i * row_numel_ + k];
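Review note on the `size_t` → `int64_t` switch above: `rows_` points at `int64_t` row ids, so `i == *(rows_ + j)` with an unsigned `i` mixes signedness; making the loop counters signed matches the element type and avoids `-Wsign-compare` noise. Below is a minimal, self-contained sketch of the same two-pointer walk over sorted sparse rows; the names are illustrative, not Paddle's:

```cpp
#include <cstdint>
#include <vector>

// Two-pointer walk mirroring the loop above: i scans every parameter row,
// j advances through the sorted list of rows that actually received
// gradients. All indices are signed int64_t so they compare cleanly
// against the int64_t row ids.
void SparseUpdateSketch(const std::vector<int64_t>& grad_rows,
                        int64_t row_count, int64_t row_numel) {
  int64_t j = 0;
  for (int64_t i = 0; i != row_count; ++i) {
    if (j != static_cast<int64_t>(grad_rows.size()) && i == grad_rows[j]) {
      // This row has a gradient: run the full Adam update on each element.
      for (int64_t k = 0; k != row_numel; ++k) {
        // adam_update(i * row_numel + k, grad[j * row_numel + k]);
      }
      ++j;
    } else {
      // No gradient for this row: the functor still decays the moments,
      // i.e. an Adam step taken with g == 0.
      for (int64_t k = 0; k != row_numel; ++k) {
        // moment1/moment2/param update at i * row_numel + k with g == 0
      }
    }
  }
}
```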
@@ -427,43 +428,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }
 
-      framework::SelectedRows cpu_grad_merge;
+      framework::SelectedRows tmp_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
         grad_merge_ptr = &grad;
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
-        framework::SelectedRows* grad_merge_var;
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        if (platform::is_cpu_place(ctx.GetPlace())) {
-          grad_merge_var = &cpu_grad_merge;
-        } else {
-          // FIXME(qiao): GPU also need to fix this
-          grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
-                               .Var()
-                               ->GetMutable<framework::SelectedRows>();
-        }
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var, true);
-        grad_merge_ptr = grad_merge_var;
+                   &tmp_grad_merge, true);
+        grad_merge_ptr = &tmp_grad_merge;
       }
 
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = nullptr;
-      // When compiled without CUDA, the CUDAData() interface should not be
-      // provided.
-#if defined(PADDLE_WITH_CUDA)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = grad_merge.rows().data();
-#if defined(PADDLE_WITH_CUDA)
-      }
-#endif
+      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
       if (platform::is_cpu_place(ctx.GetPlace())) {
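Review note: two cleanups land in this hunk. First, the CPU-only stack temporary plus the GPU workaround that leaked an unnamed `Variable` into the scope are replaced by a single local `tmp_grad_merge` used for either place. Second, the `#ifdef PADDLE_WITH_CUDA` dance around `CUDAData()` vs `data()` collapses into one `rows().Data(ctx.GetPlace())` call, i.e. the rows container now dispatches on the place itself. A rough sketch of what such a place-dispatching accessor can look like; the types and names here are stand-ins, not Paddle's actual `Vector` API:

```cpp
#include <cstdint>
#include <vector>

// Stand-in for paddle::platform places, for illustration only.
enum class Place { kCPU, kGPU };

// Sketch of an accessor in the spirit of grad_merge.rows().Data(place):
// the container returns a pointer valid on the requested place, so call
// sites no longer need #ifdef PADDLE_WITH_CUDA branches.
class RowsVector {
 public:
  const int64_t* Data(Place place) const {
#if defined(PADDLE_WITH_CUDA)
    if (place == Place::kGPU) {
      return gpu_ptr_;  // device-side copy, refreshed on demand
    }
#endif
    return cpu_rows_.data();  // host pointer in the CPU-only build
  }

 private:
  std::vector<int64_t> cpu_rows_;
#if defined(PADDLE_WITH_CUDA)
  const int64_t* gpu_ptr_ = nullptr;
#endif
};
```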
@@ -488,7 +469,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }
 #ifndef _WIN32
-      else if (FLAGS_inner_op_parallelism > 1 &&
+      else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
                min_row_size_to_use_multithread > 0 &&
                param.dims()[0] > min_row_size_to_use_multithread) {
         VLOG(3) << "use multi thread, inner_op_parallelism="
@@ -516,11 +497,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
         for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
           int64_t start = i * line_in_each_thread;
           int64_t end = (i + 1) * line_in_each_thread;
-          if (start >= param_row_count) {
+          if (start >= static_cast<int64_t>(param_row_count)) {
             break;
           }
-          if (end > param_row_count) {
-            end = param_row_count;
+          if (end > static_cast<int64_t>(param_row_count)) {
+            end = static_cast<int64_t>(param_row_count);
           }
           fs.push_back(
               framework::Async([&functor, &row_id_to_grad_row_offset,
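Review note: `param_row_count` is evidently unsigned, so the bare `start >= param_row_count` comparisons mixed signedness with the `int64_t` chunk bounds; the casts make that explicit. The partitioning itself is plain contiguous chunking. A runnable sketch follows; the ceil division for `line_in_each_thread` is my assumption, since this hunk does not show how it is computed:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Split param_row_count rows into inner_op_parallelism contiguous
// [start, end) chunks, clamping the last chunk and skipping empty ones,
// as the loop above does. Values are made up for illustration.
int main() {
  const std::size_t param_row_count = 10;  // unsigned, as in the kernel
  const int inner_op_parallelism = 4;
  const int64_t rows = static_cast<int64_t>(param_row_count);
  const int64_t line_in_each_thread =
      (rows + inner_op_parallelism - 1) / inner_op_parallelism;  // ceil
  for (int i = 0; i < inner_op_parallelism; ++i) {
    const int64_t start = i * line_in_each_thread;
    if (start >= rows) break;  // more workers than rows: stop early
    const int64_t end = std::min((i + 1) * line_in_each_thread, rows);
    std::cout << "worker " << i << " handles rows [" << start << ", " << end
              << ")\n";
  }
  return 0;
}
```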
@@ -545,8 +526,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
         for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
       }
-#endif  // !_WIN32
-      else {
+#endif        // !_WIN32
+      else {  // NOLINT
         functor(param.numel());
       }
     } else if (platform::is_gpu_place(ctx.GetPlace())) {
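Review note: the `// NOLINT` markers here and in the earlier `#ifndef _WIN32` hunk silence cpplint. Because the preprocessor guard sits between the closing `}` and the `else if`/`else`, cpplint's "an else should appear on the same line as the preceding }" check presumably misfires on these lines; the suppressions keep the Windows guard intact without failing the lint step.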