@@ -1,35 +1,31 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/lookup_table_op.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-
 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
-__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
-                            const int N, const int K, const int D) {
+__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
+                            const int64_t N, const int64_t K, const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
 
   while (idy < K) {
-    int id = ids[idy];
+    int64_t id = ids[idy];
     PADDLE_ASSERT(id >= 0);
     PADDLE_ASSERT(id < N);
     T* out = output + idy * D;
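
Old lines 36-41 are unchanged context that this hunk elides. Each thread strides across the D embedding columns while idy walks the K ids; a minimal sketch of the elided copy loop, assuming it simply gathers row id of the table into the output slot for idy:

    const T* tab = table + id * D;    // source row in the embedding table
    for (int i = idx; i < D; i += BlockDimX) {
      out[i] = tab[i];                // strided copy of one embedding row
    }
    idy += BlockDimY * GridDimX;      // advance to the next batch of ids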
@@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
 }
 
 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
-__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
-                                const int N, const int K, const int D) {
+__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
+                                const int64_t N, const int64_t K,
+                                const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
 
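
The gradient kernel scatters rather than gathers, and several entries of ids may name the same table row, so the loop body elided here has to accumulate atomically rather than store. A sketch, assuming the CudaAtomicAdd helper that paddle/platform/cuda_helper.h provides:

    const T* out = output + idy * D;  // incoming gradient for this id
    T* tab = table + id * D;          // table row to accumulate into
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);  // safe for duplicate ids
    }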
@@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
     size_t K = ids_t->numel();
-    auto ids = ids_t->data<int32_t>();
+    auto ids = ids_t->data<int64_t>();
     auto table = table_t->data<T>();
     auto output = output_t->mutable_data<T>(context.GetPlace());
 
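
The forward launch itself falls outside this hunk; presumably it mirrors the gradient path below, with a fixed 128x8 block shape matching the kernel's template parameters, along the lines of:

    dim3 threads(128, 8);
    dim3 grids(8, 1);
    LookupTable<T, 128, 8, 8><<<grids, threads, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(
            context.device_context()).stream()>>>(output, table, ids, N, K, D);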
@@ -88,34 +85,71 @@ template <typename T>
 class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto ids_t = context.Input<Tensor>("Ids");
-    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
-
-    int N = d_table_t->dims()[0];
-    int D = d_table_t->dims()[1];
-    int K = ids_t->numel();
-    const int32_t* ids = ids_t->data<int32_t>();
-    const T* d_output = d_output_t->data<T>();
-    T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
-
-    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
-    t.device(context.GetEigenDevice<platform::GPUPlace>()) =
-        t.constant(static_cast<T>(0));
-
-    dim3 threads(128, 8);
-    dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8, 8><<<
-        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+    bool is_sparse = context.Attr<bool>("is_sparse");
+    if (is_sparse) {
+      auto* ids = context.Input<Tensor>("Ids");
+      auto* table = context.Input<Tensor>("W");
+      auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
+
+      auto* ids_data = ids->data<int64_t>();
+      auto ids_dim = ids->dims();
+
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        context.device_context())
+                        .stream();
+      // copy the selected ids from GPU memory into the host-side rows vector
+      framework::Vector<int64_t> new_rows;
+      new_rows.resize(ids_dim[0]);
+      auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
+
+      memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
+                   ids_dim[0] * sizeof(int64_t), stream);
+
+      d_table->set_rows(new_rows);
+
+      auto* d_table_value = d_table->mutable_value();
+      d_table_value->Resize({ids_dim[0], table->dims()[1]});
+      d_table_value->mutable_data<T>(context.GetPlace());
+
+      auto* d_table_data = d_table_value->data<T>();
+      auto* d_output_data = d_output->data<T>();
+      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
+      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
+                   d_output->numel() * sizeof(T), stream);
+
+    } else {
+      auto ids_t = context.Input<Tensor>("Ids");
+      auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
+
+      int N = d_table_t->dims()[0];
+      int D = d_table_t->dims()[1];
+      int K = ids_t->numel();
+      const int64_t* ids = ids_t->data<int64_t>();
+      const T* d_output = d_output_t->data<T>();
+      T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
+
+      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+      t.device(context.GetEigenDevice<platform::GPUPlace>()) =
+          t.constant(static_cast<T>(0));
+
+      dim3 threads(128, 8);
+      dim3 grids(8, 1);
+      LookupTableGrad<T, 128, 8,
+                      8><<<grids, threads, 0,
+                           reinterpret_cast<const platform::CUDADeviceContext&>(
                                context.device_context())
                                .stream()>>>(d_table, d_output, ids, N, K, D);
+    }
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
-REGISTER_OP_GPU_KERNEL(lookup_table_grad,
-                       ops::LookupTableGradCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
+                       ops::LookupTableCUDAKernel<double>);
+REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
+                       ops::LookupTableGradCUDAKernel<double>);
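
With is_sparse enabled, the W gradient comes back as a SelectedRows instead of a dense tensor: rows() lists the looked-up ids and value() holds one D-wide gradient row per id, so only the touched rows ever leave the kernel. A hypothetical consumer (the names d_dense and d_table_data are illustrative, not from this diff) would scatter it back to dense form roughly as:

    // d_dense has shape N x D; row i of the value tensor belongs to rows()[i].
    for (size_t i = 0; i < d_table->rows().size(); ++i) {
      int64_t row = d_table->rows()[i];
      for (int64_t j = 0; j < D; ++j) {
        d_dense[row * D + j] += d_table_data[i * D + j];
      }
    }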