
Commit 6429d2a

Merge pull request #16188 from sneaxiy/fix_const_cast
Remove const_cast in optimizers
2 parents: e818fa1 + f0d108f

File tree

3 files changed (+25 lines, -61 lines)
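
All three kernels change in the same way, as the diffs below show: instead of const_cast-ing the kernel's const scope to carve out a scratch variable for the merged gradient, they declare a function-local temporary SelectedRows and pass its address to the merge functor. The following is a minimal, self-contained sketch of that before/after pattern; Scope, Var, SelectedRows, and MergeAdd here are toy stand-ins, not the real Paddle types.

#include <memory>
#include <vector>

// Toy stand-ins for the framework types touched by this commit.
struct SelectedRows {
  std::vector<long long> rows;
};

struct Var {
  SelectedRows sr;
  SelectedRows* GetMutable() { return &sr; }
};

struct Scope {
  // Creates a new unnamed variable that the scope keeps owning.
  Var* NewVar() {
    vars.emplace_back(new Var());
    return vars.back().get();
  }
  std::vector<std::unique_ptr<Var>> vars;
};

// Placeholder for the real row-merging logic (math::scatter::MergeAdd).
void MergeAdd(const SelectedRows& in, SelectedRows* out) { out->rows = in.rows; }

// Before: the kernel only sees a const Scope&, so it const_casts it to get
// scratch space; the scope is mutated and keeps the new variable afterwards.
void KernelBefore(const Scope& scope, const SelectedRows& grad) {
  SelectedRows* merged = const_cast<Scope&>(scope).NewVar()->GetMutable();
  MergeAdd(grad, merged);
}

// After: a plain local temporary; no const_cast, nothing outlives the call.
void KernelAfter(const Scope& /*scope*/, const SelectedRows& grad) {
  SelectedRows tmp_merged;
  MergeAdd(grad, &tmp_merged);
}

int main() {
  Scope scope;
  SelectedRows grad{{0, 2, 2, 5}};
  KernelBefore(scope, grad);  // scope.vars.size() is now 1
  KernelAfter(scope, grad);   // scope is left untouched
  return 0;
}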

paddle/fluid/operators/optimizers/adam_op.h

Lines changed: 15 additions & 34 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <Eigen/Dense>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -311,17 +312,17 @@ struct SparseAdamFunctor<T, CPUAdam> {
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-    size_t row_count = numel / row_numel_;
+    int64_t row_count = static_cast<int64_t>(numel / row_numel_);

-    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+    for (int64_t i = 0, j = 0; i != row_count; ++i) {
       if (i == *(rows_ + j)) {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T g = grad_[j * row_numel_ + k];
           adam_update(i * row_numel_ + k, g);
         }
         ++j;
       } else {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T mom1 = moment1_[i * row_numel_ + k];
           T mom2 = moment2_[i * row_numel_ + k];
           T p = param_[i * row_numel_ + k];
@@ -427,43 +428,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }

-      framework::SelectedRows cpu_grad_merge;
+      framework::SelectedRows tmp_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
         grad_merge_ptr = &grad;
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
-        framework::SelectedRows* grad_merge_var;
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        if (platform::is_cpu_place(ctx.GetPlace())) {
-          grad_merge_var = &cpu_grad_merge;
-        } else {
-          // FIXME(qiao): GPU also need to fix this
-          grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
-                               .Var()
-                               ->GetMutable<framework::SelectedRows>();
-        }
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var, true);
-        grad_merge_ptr = grad_merge_var;
+                   &tmp_grad_merge, true);
+        grad_merge_ptr = &tmp_grad_merge;
       }

       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = nullptr;
-// When compiled without CUDA, the CUDAData() interface should not be
-// provided.
-#if defined(PADDLE_WITH_CUDA)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = grad_merge.rows().data();
-#if defined(PADDLE_WITH_CUDA)
-      }
-#endif
+      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

       if (platform::is_cpu_place(ctx.GetPlace())) {
@@ -488,7 +469,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
         }
 #ifndef _WIN32
-        else if (FLAGS_inner_op_parallelism > 1 &&
+        else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
                  min_row_size_to_use_multithread > 0 &&
                  param.dims()[0] > min_row_size_to_use_multithread) {
           VLOG(3) << "use multi thread, inner_op_parallelism="
@@ -516,11 +497,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
             int64_t start = i * line_in_each_thread;
             int64_t end = (i + 1) * line_in_each_thread;
-            if (start >= param_row_count) {
+            if (start >= static_cast<int64_t>(param_row_count)) {
               break;
             }
-            if (end > param_row_count) {
-              end = param_row_count;
+            if (end > static_cast<int64_t>(param_row_count)) {
+              end = static_cast<int64_t>(param_row_count);
             }
             fs.push_back(
                 framework::Async([&functor, &row_id_to_grad_row_offset,
@@ -545,8 +526,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
           for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
         }
-#endif        // !_WIN32
-        else {
+#endif          // !_WIN32
+        else {  // NOLINT
          functor(param.numel());
        }
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
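
Two smaller edits in adam_op.h are type hygiene: the sparse functor's loop indices move from size_t to int64_t, matching the int64_t row ids compared against in i == *(rows_ + j), and the thread-partitioning bounds now cast param_row_count to int64_t before comparing with the signed start and end. A small standalone example of the mixed-signedness pitfall these casts avoid (toy values only, not Paddle code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> rows = {0, 3, 7};  // sparse row ids are int64_t, as in the diff
  size_t row_count = rows.size();

  // Mixing a signed bound with an unsigned count triggers -Wsign-compare and
  // misbehaves if the signed side is ever negative: -1 converted to size_t is
  // a huge value, so an unsigned comparison treats it as larger than the count.
  int64_t end = -1;
  if (static_cast<size_t>(end) > row_count) {
    std::printf("-1 as size_t = %zu, so the unsigned comparison is true\n",
                static_cast<size_t>(end));
  }

  // The diff keeps the index signed and casts the count once, mirroring
  // `if (end > static_cast<int64_t>(param_row_count))`.
  if (end > static_cast<int64_t>(row_count)) {
    end = static_cast<int64_t>(row_count);
  }
  std::printf("signed comparison leaves end = %lld\n",
              static_cast<long long>(end));
  return 0;
}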

paddle/fluid/operators/optimizers/momentum_op.h

Lines changed: 6 additions & 13 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <memory>
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -69,6 +70,7 @@ class MomentumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
@@ -351,23 +353,14 @@ class MomentumOpKernel : public framework::OpKernel<T> {
         VLOG(3) << "Grad SelectedRows contains no data!";
         return;
       }
-      auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
-                              .Var()
-                              ->GetMutable<framework::SelectedRows>();
+
+      framework::SelectedRows tmp_merged_grad;
+      framework::SelectedRows* merged_grad = &tmp_merged_grad;
       math::scatter::MergeAdd<DeviceContext, T> merge_func;
       merge_func(ctx.template device_context<DeviceContext>(), *grad,
                  merged_grad);

-      const int64_t* rows = nullptr;
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = merged_grad->rows().data();
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
+      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
       int64_t row_numel =
           merged_grad->value().numel() / merged_grad->rows().size();
       platform::ForRange<DeviceContext> for_range(
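
The other recurring edit, here and in adam_op.h and rmsprop_op.h, is replacing the #ifdef PADDLE_WITH_CUDA branching between rows().CUDAData(place) and rows().data() with a single rows().Data(place) call, so the place dispatch lives behind the rows container instead of at every call site. A minimal sketch of that accessor shape; Place and RowVector are toy stand-ins for whatever merged_grad->rows() actually returns, and the CUDA path is only hinted at in a comment so the sketch stays host-only:

#include <cstdint>
#include <vector>

// Toy stand-in for the place descriptor; the kernels pass ctx.GetPlace().
enum class Place { kCPU, kCUDA };

// Toy stand-in for the container behind merged_grad->rows(): a single
// Data(place) accessor replaces per-call-site #ifdef PADDLE_WITH_CUDA blocks.
class RowVector {
 public:
  explicit RowVector(std::vector<int64_t> rows) : cpu_(std::move(rows)) {}

  const int64_t* Data(Place place) const {
    if (place == Place::kCUDA) {
      // A real implementation would return (or lazily copy to) a device
      // buffer here; callers no longer care which branch is taken.
      return cpu_.data();  // host-only placeholder to keep the sketch buildable
    }
    return cpu_.data();
  }

 private:
  std::vector<int64_t> cpu_;
};

int main() {
  RowVector rows({1, 4, 6});
  const int64_t* ptr = rows.Data(Place::kCPU);  // what the kernels now call
  return ptr == nullptr ? 1 : 0;
}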

paddle/fluid/operators/optimizers/rmsprop_op.h

Lines changed: 4 additions & 14 deletions
@@ -216,24 +216,14 @@ class RmspropOpKernel : public framework::OpKernel<T> {
       }
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto &grad = grad_var->Get<framework::SelectedRows>();
-      auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
-                              .Var()
-                              ->GetMutable<framework::SelectedRows>();
-
+      framework::SelectedRows tmp_merged_grad;
+      framework::SelectedRows *merged_grad = &tmp_merged_grad;
       math::scatter::MergeAdd<DeviceContext, T> merge_func;
       merge_func(dev_ctx, grad, merged_grad);

       platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
-      const int64_t *rows;
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = merged_grad->rows().data();
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
+      const int64_t *rows = merged_grad->rows().Data(ctx.GetPlace());
+
       auto &merged_tensor = merged_grad->value();
       int64_t row_count = merged_grad->rows().size();
       int64_t row_numel = merged_tensor.numel() / row_count;
