Skip to content

Commit 2ecc562

Browse files
panyx0718 and wanghaoshuang
authored and committed
small AverageOptimizer enhance. (#11761)
* small AverageOptimizer enhance.
* clean
* clean
1 parent 19e877f commit 2ecc562

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

paddle/fluid/operators/average_accumulates_op.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,28 @@ namespace operators {
1919

2020
template <>
2121
void GetAccumulators<paddle::platform::CPUDeviceContext>(
22-
const framework::ExecutionContext& ctx, int64_t* num_updates_,
23-
int64_t* num_accumulates_, int64_t* old_num_accumulates_) {
22+
const framework::ExecutionContext& ctx, int64_t* num_updates,
23+
int64_t* num_accumulates, int64_t* old_num_accumulates) {
2424
auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
2525
auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
2626
auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
2727

28-
*old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
29-
*num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
30-
*num_updates_ = in_num_updates->data<int64_t>()[0];
28+
*old_num_accumulates = in_old_num_accumulates->data<int64_t>()[0];
29+
*num_accumulates = in_num_accumulates->data<int64_t>()[0];
30+
*num_updates = in_num_updates->data<int64_t>()[0];
3131
}
3232

3333
template <>
3434
void SetAccumulators<paddle::platform::CPUDeviceContext>(
35-
const framework::ExecutionContext& ctx, int64_t num_updates_,
36-
int64_t num_accumulates_, int64_t old_num_accumulates_) {
35+
const framework::ExecutionContext& ctx, int64_t num_updates,
36+
int64_t num_accumulates, int64_t old_num_accumulates) {
3737
auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
3838
auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
3939
auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
4040

41-
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
42-
out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
43-
out_num_updates->data<int64_t>()[0] = num_updates_;
41+
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates;
42+
out_num_accumulates->data<int64_t>()[0] = num_accumulates;
43+
out_num_updates->data<int64_t>()[0] = num_updates;
4444
}
4545

4646
class AverageAccumulatesOp : public framework::OperatorWithKernel {
@@ -177,7 +177,7 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
177177

178178
AddComment(R"DOC(
179179
AverageAccumulates Operator.
180-
Accumulate the sum of parameter whtin sliding window. The size of sliding window is
180+
Accumulate the sum of parameter within sliding window. The size of sliding window is
181181
determined by 'average_window', 'max_average_window' and 'min_average_window'.
182182
Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'.
183183
'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'.

paddle/fluid/operators/average_accumulates_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
5454
float average_window = ctx.Attr<float>("average_window");
5555
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
5656
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
57-
min_average_window =
58-
std::min<int64_t>(min_average_window, max_average_window);
57+
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
58+
"min_average_window shouldn't be larger than "
59+
"max_average_window");
5960

6061
// Get inputs
6162
auto* param = ctx.Input<Tensor>("param");

0 commit comments

Comments
 (0)