Skip to content

Commit d432b10

Browse files
committed
Update cuda kernel and doc.
1 parent e03b574 commit d432b10

File tree

2 files changed: +7 -3 lines changed

paddle/operators/momentum_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
71 71   "(Tensor, default Tensor<float>) "
72 72   "Input learning rate");
73 73
74    - AddOutput("ParamOut", "(Tensor) Output updated parameter");
75    - AddOutput("VelocityOut", "(Tensor) Output updated velocity");
   74 + AddOutput("ParamOut",
   75 + "(Tensor) This output is updated parameter. "
   76 + "It shared memory with Input(Param).");
   77 + AddOutput("VelocityOut",
   78 + "(Tensor) This output is updated velocity. "
   79 + "It shared memory with Input(Velocity).");
76 80
77 81   AddAttr<float>("mu", "(float) Momentum coefficient");
78 82   AddAttr<bool>("use_nesterov",

paddle/operators/momentum_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
29 29   T g_val = g[i];
30 30   T v_new = v[i] * mu + g_val;
31 31   v_out[i] = v_new;
32    - p_out[i] = p[i] - g_val * lr + v_new * mu * lr;
   32 + p_out[i] = p[i] - (g_val - v_new * mu) * lr;
33 33   }
34 34   } else {
35 35   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;

0 commit comments

Comments
 (0)