@@ -27,6 +27,12 @@ namespace {
 
 template <typename TKey, typename TValue>
 struct GroupEmbeddingBackWardArgs {
+  GroupEmbeddingBackWardArgs() = default;
+  GroupEmbeddingBackWardArgs(TValue *grads, TKey *sp_values,
+                             TValue *emb_variable, TValue *grads_output,
+                             int *offset_indices, int nnz)
+      : grads_(grads), sp_values_(sp_values), emb_variable_(emb_variable),
+        grads_output_(grads_output), offset_indices_(offset_indices), nnz_(nnz) {}
   TValue *grads_;
   TKey *sp_values_;
   TValue *emb_variable_;
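For orientation, the new constructor simply packs one lookup table's device pointers on the host; a minimal usage sketch (the pointer names and the `lookuper` object below are placeholders, not identifiers taken from this diff):

  // Hypothetical host-side packing of one lookup table's pointers.
  GroupEmbeddingBackWardArgs<int64_t, float> arg(
      /*grads=*/d_grads, /*sp_values=*/d_sp_values,
      /*emb_variable=*/d_emb_table, /*grads_output=*/d_grads_out,
      /*offset_indices=*/d_offsets, /*nnz=*/nnz);
  lookuper.set(arg);  // appended to h_args_, uploaded once in Backward()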
@@ -59,26 +65,29 @@ __global__ void ComputeEVGradFn(
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
 
-    float grad = args[idx].grads_[bid * dimension + tid];
-    grad = CombineGrad<combiner>(grad, feature_num);
-
-    for (int j = 0; j < feature_num; ++j) {
-      float grad_i = grad;
-      int feature_offset = (value_offset + j) * dimension;
-      if (max_norm > 0.0f) {
-        float emb_element = 0.0f;  // TODO: hujunqi get emb_weight
-        if (tid == 0) {
-          l2_sum = 0.0f;
-        }
-        tile.shfl(l2_sum, 0);
-        atomicAdd(&l2_sum, emb_element * emb_element);
-        tile.sync();
-        float l2_norm = sqrtf(l2_sum);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+
+      for (int j = 0; j < feature_num; ++j) {
+        float grad_i = grad;
+        int feature_offset = (value_offset + j) * dimension;
+        if (max_norm > 0.0f) {
+          float emb_element = 0.0f;  // TODO(junqihu): get emb_weight
+          if (tid == 0) {
+            l2_sum = 0.0f;
+          }
+          tile.shfl(l2_sum, 0);
+          atomicAdd(&l2_sum, emb_element * emb_element);
+          tile.sync();
+          float l2_norm = sqrtf(l2_sum);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + j) * dimension + tid] =
+            grad_i;
       }
-      args[idx].grads_output_[(value_offset + j) * dimension + tid] = grad_i;
     }
   }
 }
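The kernel scales the incoming row gradient once via CombineGrad<combiner> before fanning it out to the row's features. Its definition is outside this hunk; a self-contained sketch of what such a helper typically does (the enum and enumerator names here are assumptions, and the real Combiner used in this file may differ):

  // Sketch only: "mean" divides by the row length, "sqrtn" by sqrt(n),
  // and "sum" passes the gradient through unchanged.
  enum class Combiner { kSum, kMean, kSqrtn };

  template <Combiner combiner>
  __device__ __forceinline__ float CombineGrad(float grad, int feature_num) {
    if (combiner == Combiner::kMean) return grad / feature_num;
    if (combiner == Combiner::kSqrtn)
      return grad * rsqrtf(static_cast<float>(feature_num));
    return grad;  // kSum
  }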
@@ -105,31 +114,29 @@ __global__ void ComputeSparseGradFn(
     } else {
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
-    float grad = args[idx].grads_[bid * dimension + tid];
-    // printf("feature_num is %d, grad is %lld, bid is %d, tid is %d\n",
-    //        feature_num, grad, blockIdx.x, threadIdx.x);
-    grad = CombineGrad<combiner>(grad, feature_num);
-    for (int i = 0; i < feature_num; i++) {
-      float grad_i = grad;
-      if (max_norm > 0.0f) {
-        int64_t indices = int(args[idx].sp_values_[value_offset + i]);
-        float emb_element = args[idx].emb_variable_[indices * dimension + tid];
-        // if (FastBoundsCheck(indices, args[idx].emb_row_size_)) {
-        //   emb_element =
-        //       args[idx].emb_variable_[indices * dimension + tid];
-        // }
-        if (tid == 0) {
-          l2_sum = 0.0f;
-        }
-        tile.shfl(l2_sum, 0);
-        atomicAdd(&l2_sum, emb_element * emb_element);
-        tile.sync();
-        float l2_norm = sqrtf(l2_sum);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+      for (int i = 0; i < feature_num; i++) {
+        float grad_i = grad;
+        if (max_norm > 0.0f) {
+          int64_t indices = int(args[idx].sp_values_[value_offset + i]);
+          float emb_element =
+              args[idx].emb_variable_[indices * dimension + tid];
+          if (tid == 0) {
+            l2_sum = 0.0f;
+          }
+          tile.shfl(l2_sum, 0);
+          atomicAdd(&l2_sum, emb_element * emb_element);
+          tile.sync();
+          float l2_norm = sqrtf(l2_sum);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
       }
-      args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
     }
   }
 }
@@ -156,26 +163,29 @@ __global__ void NormalComputeEVGradFn(
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
 
-    float grad = args[idx].grads_[bid * dimension + tid];
-    grad = CombineGrad<combiner>(grad, feature_num);
-
-    for (int j = 0; j < feature_num; ++j) {
-      float grad_i = grad;
-      int feature_offset = (value_offset + j) * dimension;
-      if (max_norm > 0.0f) {
-        float emb_element = 0.0f;  // TODO: hujunqi get emb_weight
-        if (tid == 0) {
-          l2_sum[0] = 0.0f;
-        }
-        __syncthreads();
-        atomicAdd(l2_sum, emb_element * emb_element);
-        __syncthreads();
-        float l2_norm = sqrtf(l2_sum[0]);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+
+      for (int j = 0; j < feature_num; ++j) {
+        float grad_i = grad;
+        int feature_offset = (value_offset + j) * dimension;
+        if (max_norm > 0.0f) {
+          float emb_element = 0.0f;  // TODO(junqihu): get emb_weight
+          if (tid == 0) {
+            l2_sum[0] = 0.0f;
+          }
+          __syncthreads();
+          atomicAdd(l2_sum, emb_element * emb_element);
+          __syncthreads();
+          float l2_norm = sqrtf(l2_sum[0]);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + j) * dimension + tid] =
+            grad_i;
       }
-      args[idx].grads_output_[(value_offset + j) * dimension + tid] = grad_i;
     }
   }
 }
@@ -201,27 +211,29 @@ __global__ void NormalComputeSparseGradFn(
     } else {
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
-    float grad = args[idx].grads_[bid * dimension + tid];
-    // printf("feature_num is %d, grad is %lld, bid is %d, tid is %d\n",
-    //        feature_num, grad, blockIdx.x, threadIdx.x);
-    grad = CombineGrad<combiner>(grad, feature_num);
-    for (int i = 0; i < feature_num; i++) {
-      float grad_i = grad;
-      if (max_norm > 0.0f) {
-        int64_t indices = int(args[idx].sp_values_[value_offset + i]);
-        float emb_element = args[idx].emb_variable_[indices * dimension + tid];
-        if (tid == 0) {
-          l2_sum[0] = 0.0f;
-        }
-        __syncthreads();
-        atomicAdd(l2_sum, emb_element * emb_element);
-        __syncthreads();
-        float l2_norm = sqrtf(l2_sum[0]);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+      for (int i = 0; i < feature_num; i++) {
+        float grad_i = grad;
+        if (max_norm > 0.0f) {
+          int64_t indices = int(args[idx].sp_values_[value_offset + i]);
+          float emb_element =
+              args[idx].emb_variable_[indices * dimension + tid];
+          if (tid == 0) {
+            l2_sum[0] = 0.0f;
+          }
+          __syncthreads();
+          atomicAdd(l2_sum, emb_element * emb_element);
+          __syncthreads();
+          float l2_norm = sqrtf(l2_sum[0]);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
      }
-      args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
    }
   }
 }
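All four kernels apply the same max_norm rule: when the looked-up embedding row's L2 norm exceeds max_norm, the gradient written for that row is scaled by max_norm / l2_norm. A scalar host-side reference of that factor, handy for unit tests (not part of this diff):

  #include <cmath>

  // Reference for the per-row clipping factor computed by the kernels above.
  inline float MaxNormScale(const float *row, int dim, float max_norm) {
    float l2_sum = 0.0f;
    for (int d = 0; d < dim; ++d) l2_sum += row[d] * row[d];
    const float l2_norm = std::sqrt(l2_sum);
    return (max_norm > 0.0f && l2_norm > max_norm) ? max_norm / l2_norm : 1.0f;
  }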
@@ -231,52 +243,54 @@ __global__ void NormalComputeSparseGradFn(
 template <typename TKey, typename TValue>
 class GroupEmbeddingLookupBackWard {
  public:
-  void initialize(int dimension, int num_lookups, float max_norm) {
-    CK_CUDA_THROW_(cudaMalloc(
-        &d_args_,
-        sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>) * num_lookups));
-    args_.resize(num_lookups);
+  explicit GroupEmbeddingLookupBackWard(int dimension, int num_lookups,
+                                        float max_norm,
+                                        Allocator *gpu_allocator = nullptr)
+      : alloc_(gpu_allocator) {
+    d_args_ =
+        TypedAllocator::Allocate<GroupEmbeddingBackWardArgs<TKey, TValue>>(
+            gpu_allocator, num_lookups, AllocationAttributes());
+    h_args_.reserve(num_lookups);
     max_norm_ = max_norm;
     nums_ = num_lookups;
     dimension_ = dimension;
   }
 
-  void set(int idx, TValue *grads, TValue *grads_output, int *offset_indices,
-           TKey *sp_values, TValue *emb_variable, int nnz) {
-    args_[idx].grads_ = grads;
-    args_[idx].grads_output_ = grads_output;
-    args_[idx].offset_indices_ = offset_indices;
-    args_[idx].sp_values_ = sp_values;
-    args_[idx].emb_variable_ = emb_variable;
-    args_[idx].nnz_ = nnz;
+  void set(GroupEmbeddingBackWardArgs<TKey, TValue> &arg) {
+    h_args_.emplace_back(arg);
   }
 
   ~GroupEmbeddingLookupBackWard() {
-    if (d_args_) {
-      CK_CUDA_THROW_(cudaFree(d_args_));
-    }
+    TypedAllocator::Deallocate(alloc_, d_args_, nums_);
   }
 
   template <typename GradFn>
-  void Backward(GradFn fn, int batch_size, int tile_size, cudaStream_t stream) {
+  inline void Backward(GradFn fn, int batch_size, int tile_size,
+                       cudaStream_t stream) {
     CK_CUDA_THROW_(cudaMemcpyAsync(
-        d_args_, args_.data(),
-        args_.size() * sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>),
+        d_args_, h_args_.data(),
+        h_args_.size() * sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>),
         cudaMemcpyHostToDevice, stream));
 
     {
-      const int block_size = (batch_size - 1) / 64 * tile_size + 1;
+      if (tile_size <= 32) {
+        const int block_size = batch_size / 64 * tile_size + 1;
 
-      fn<<<block_size, 64, 0, stream>>>(batch_size, max_norm_, nums_,
-                                        dimension_, d_args_);
+        fn<<<block_size, 64, 0, stream>>>(batch_size, max_norm_, nums_,
+                                          dimension_, d_args_);
+      } else {
+        fn<<<batch_size, tile_size, 0, stream>>>(batch_size, max_norm_, nums_,
+                                                 dimension_, d_args_);
+      }
     }
 
     CK_CUDA_THROW_(cudaGetLastError());
   }
 
  protected:
-  std::vector<GroupEmbeddingBackWardArgs<TKey, TValue>> args_;
+  std::vector<GroupEmbeddingBackWardArgs<TKey, TValue>> h_args_;
   GroupEmbeddingBackWardArgs<TKey, TValue> *d_args_;
+  Allocator *alloc_;
   float max_norm_;
   int nums_;
   int dimension_;
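GroupEmbeddingLookupBackWard now takes a TensorFlow Allocator instead of calling cudaMalloc/cudaFree directly, and collects the per-table args on the host before a single async upload. A hedged sketch of how an op might drive it; `ctx`, the pointer arrays, and `nnz` stand in for values the real op reads from its inputs:

  // Hypothetical caller inside an op's Compute(); error handling omitted.
  Allocator *gpu_alloc = ctx->get_allocator(AllocatorAttributes());
  GroupEmbeddingLookupBackWard<int64_t, float> lookuper(
      dimension_, num_lookups_, max_norm_, gpu_alloc);
  for (int i = 0; i < num_lookups_; ++i) {
    GroupEmbeddingBackWardArgs<int64_t, float> arg(
        grads[i], sp_values[i], emb_tables[i], grads_out[i], offsets[i],
        nnz[i]);
    lookuper.set(arg);
  }
  // compute<Isev, combiner>(lookuper, batch_size, stream) below then picks
  // the kernel and tile size and calls lookuper.Backward(...).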
@@ -290,56 +304,55 @@ class GroupLookupBackWardBaseOp : public OpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("max_norm", &max_norm_));
     OP_REQUIRES_OK(c, c->GetAttr("num_lookups", &num_lookups_));
     OP_REQUIRES_OK(c, c->GetAttr("dimension", &dimension_));
-    lookuper_.initialize(dimension_, num_lookups_, max_norm_);
   }
 
   template <bool Isev = false, Combiner combiner>
-  inline void compute(const int batch_size, cudaStream_t stream) {
+  inline void compute(GroupEmbeddingLookupBackWard<TKey, TValue> &lookuper,
+                      const int batch_size, cudaStream_t stream) {
     if (Isev) {
       if (dimension_ <= 2) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 2>,
-                           batch_size, 2, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 2>,
+                          batch_size, 2, stream);
      } else if (dimension_ <= 4) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 4>,
-                           batch_size, 4, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 4>,
+                          batch_size, 4, stream);
      } else if (dimension_ <= 8) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 8>,
-                           batch_size, 8, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 8>,
+                          batch_size, 8, stream);
      } else if (dimension_ <= 16) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 16>,
-                           batch_size, 16, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 16>,
+                          batch_size, 16, stream);
      } else if (dimension_ <= 32) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 32>,
-                           batch_size, 32, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 32>,
+                          batch_size, 32, stream);
      } else {
-        lookuper_.Backward(NormalComputeEVGradFn<TKey, TValue, combiner>,
-                           batch_size, 64, stream);
+        lookuper.Backward(NormalComputeEVGradFn<TKey, TValue, combiner>,
+                          batch_size, dimension_, stream);
      }
     } else {
       if (dimension_ <= 2) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 2>,
-                           batch_size, 2, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 2>,
+                          batch_size, 2, stream);
      } else if (dimension_ <= 4) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 4>,
-                           batch_size, 4, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 4>,
+                          batch_size, 4, stream);
      } else if (dimension_ <= 8) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 8>,
-                           batch_size, 8, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 8>,
+                          batch_size, 8, stream);
      } else if (dimension_ <= 16) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 16>,
-                           batch_size, 16, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 16>,
+                          batch_size, 16, stream);
      } else if (dimension_ <= 32) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 32>,
-                           batch_size, 32, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 32>,
+                          batch_size, 32, stream);
      } else {
-        lookuper_.Backward(NormalComputeSparseGradFn<TKey, TValue, combiner>,
-                           batch_size, 64, stream);
+        lookuper.Backward(NormalComputeSparseGradFn<TKey, TValue, combiner>,
+                          batch_size, dimension_, stream);
      }
     }
   }
 
  protected:
-  GroupEmbeddingLookupBackWard<TKey, TValue> lookuper_;
   std::string combiner_;
   float max_norm_;
   int num_lookups_;
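For reference on the launch geometry: the dispatch above maps dimension buckets 2/4/8/16/32 to the tiled kernels, and anything wider now launches the Normal* kernels with `dimension_` threads per block (one block per batch row) instead of the previous fixed 64. A toy check of the tiled-path grid arithmetic, assuming each 64-thread block covers 64 / tile_size rows:

  #include <cstdio>

  int main() {
    const int batch_size = 1024, tile_size = 8;              // example numbers
    const int block_size = batch_size / 64 * tile_size + 1;  // 129 blocks of 64 threads
    std::printf("blocks=%d rows_covered=%d\n", block_size,
                block_size * (64 / tile_size));              // 129, 1032 >= 1024
    return 0;
  }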