See the License for the specific language governing permissions and
limitations under the License. */

-#define EIGEN_USE_GPU
-#include "paddle/operators/momentum_op.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
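+// Momentum update, applied element-wise over the flattened tensors:
+//   v_new = mu * v + g
+//   plain:    p_new = p - lr * v_new
+//   Nesterov: p_new = p - lr * (g + mu * v_new)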
+template <typename T>
+__global__ void MomentumKernel(const T* p, const T* g, const T* v,
+                               const T* learning_rate, const T mu,
+                               const int64_t num, bool use_nesterov, T* p_out,
+                               T* v_out) {
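+  // LearningRate is a one-element tensor resident on the GPU, so the rate is
+  // read through a device pointer here rather than passed by value.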
+  T lr = learning_rate[0];
+  if (use_nesterov) {
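+    // Grid-stride loop: each thread handles elements i, i + stride, ..., so
+    // the kernel stays correct for any grid size.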
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T g_val = g[i];
+      T v_new = v[i] * mu + g_val;
+      v_out[i] = v_new;
+      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
+    }
+  } else {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T v_new = v[i] * mu + g[i];
+      v_out[i] = v_new;
+      p_out[i] = p[i] - lr * v_new;
+    }
+  }
+}
+
+template <typename T>
+class MomentumOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
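+    // mutable_data allocates the output tensors on the current place (the
+    // GPU here) and returns raw device pointers for the kernel launch.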
+    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
+    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
+
+    auto* p = param->data<T>();
+    auto* v = velocity->data<T>();
+    auto* g = grad->data<T>();
+    auto* lr = learning_rate->data<T>();
+
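+    // 512 threads per block and enough blocks to cover every element; the
+    // grid-stride loop keeps the kernel correct even if this is tuned down.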
+    int block = 512;
+    int grid = (param->numel() + block - 1) / block;
+    MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+        p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    momentum, ops::MomentumOpKernel<paddle::platform::GPUPlace, float>);
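+// Instantiate the GPU kernel for both float and double parameters.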
+REGISTER_OP_GPU_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
+                       ops::MomentumOpCUDAKernel<double>);