@@ -15,125 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/momentum_op.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-__global__ void MomentumKernel(const T* p, const T* g, const T* v,
-                               const T* learning_rate, const T mu,
-                               const int64_t num, bool use_nesterov, T* p_out,
-                               T* v_out) {
-  T lr = learning_rate[0];
-  if (use_nesterov) {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T g_val = g[i];
-      T v_new = v[i] * mu + g_val;
-      v_out[i] = v_new;
-      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
-    }
-  } else {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T v_new = v[i] * mu + g[i];
-      v_out[i] = v_new;
-      p_out[i] = p[i] - lr * v_new;
-    }
-  }
-}
-
-template <typename T>
-__global__ void SparseMomentumKernel(const T* p, const T* g, const T* v,
-                                     const T* lr, const T mu,
-                                     const int64_t* grad_rows,
-                                     const size_t grad_row_numel,
-                                     const size_t grad_row_size,
-                                     const T use_nesterov, T* p_out, T* v_out) {
-  for (int i = blockIdx.x; i < grad_row_size; i += gridDim.x) {
-    for (int j = threadIdx.x; j < grad_row_numel; j += blockDim.x) {
-      size_t p_i = grad_rows[i] * grad_row_numel + j;
-      size_t g_i = i * grad_row_numel + j;
-      v_out[g_i] = v[g_i] * mu + g[g_i];
-      if (use_nesterov) {
-        p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
-      } else {
-        p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
-      }
-    }
-  }
-}
-
-template <typename T>
-class MomentumOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
-
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto* velocity_var = ctx.InputVar("Velocity");
-    auto* grad_var = ctx.InputVar("Grad");
-
-    if (grad_var->IsType<framework::LoDTensor>()) {
-      PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::Tensor>("Velocity");
-      auto grad = ctx.Input<framework::Tensor>("Grad");
-      auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-      T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
-      auto* p = param->data<T>();
-      auto* v = velocity->data<T>();
-      auto* g = grad->data<T>();
-      auto* lr = learning_rate->data<T>();
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      MomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
-    } else if (grad_var->IsType<framework::SelectedRows>()) {
-      // sparse update embedding with selectedrows
-      PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
-      auto grad = ctx.Input<framework::SelectedRows>("Grad");
-      auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
-
-      // sparse update maybe empty.
-      if (grad->rows().size() == 0) {
-        return;
-      }
-      PADDLE_ENFORCE(grad->height() == velocity->height(),
-                     "Unmatched gradient and velocity.");
-      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      auto* v_out =
-          velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
-      auto* lr = learning_rate->data<T>();
-      auto* p = param->data<T>();
-      auto* g = grad->value().data<T>();
-      auto* v = velocity->value().data<T>();
-      size_t grad_row_numel = grad->value().numel() / grad->rows().size();
-      size_t grad_row_size = grad->rows().size();
-      framework::Vector<int64_t> rows(grad->rows());
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      SparseMomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, rows.CUDAData(ctx.GetPlace()), grad_row_numel,
-          grad->rows().size(), use_nesterov, p_out, v_out);
-    } else {
-      PADDLE_THROW("Unsupported Variable Type of Grad");
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
-                        ops::MomentumOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
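For reference, here is a sketch of the per-element update that the deleted MomentumKernel computed, and which the shared MomentumOpKernel from momentum_op.h (now registered for CUDA above) is assumed to reproduce; that equivalence is an assumption of this note, not something stated in the diff itself:

$$
v' = \mu v + g, \qquad
p' =
\begin{cases}
p - (g + \mu v')\,lr & \text{if } \texttt{use\_nesterov} \\
p - lr \cdot v' & \text{otherwise}
\end{cases}
$$

The deleted SparseMomentumKernel applied the same rule, restricted to the parameter rows indexed by grad_rows in the SelectedRows gradient.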