
Commit d350de3

[Embedding] Fix sp_weights indices calculation error. (#801)
1. Fix wrong indices calculation for sp_weights. 2. Fix wrong CUDA launch params when dim is greater than 64. Signed-off-by: JunqiHu <[email protected]>
1 parent 0e3c331 commit d350de3
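The second fix is easiest to see in isolation. A minimal sketch, not part of the commit and with illustrative names: each thread of a block writes one column of an embedding row, so a launch hard-coded to 64 threads per block silently skips every column at index 64 and above once the embedding dimension exceeds 64.

```cpp
// Illustrative only: one thread per embedding column of one row.
__global__ void WriteRow(float *out, int dimension) {
  int tid = threadIdx.x;
  if (tid < dimension) out[blockIdx.x * dimension + tid] = 1.0f;
}

// Old launch shape:   WriteRow<<<batch_size, 64>>>(out, 128);
//   -> columns 64..127 of every row are never written.
// Fixed launch shape: WriteRow<<<batch_size, 128>>>(out, 128);
//   -> every column gets a thread (valid while dimension <= 1024).
```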

5 files changed: +477 -427 lines

tensorflow/core/framework/embedding/single_tier_storage.h

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ class HbmStorage : public SingleTierStorage<K, V> {
                   embedding::Iterator** it) override {
     GPUHashMapKV<K, V>* gpu_kv =
         dynamic_cast<GPUHashMapKV<K, V>*>(SingleTierStorage<K, V>::kv_);
-    gpu_kv->GetSnapshot(key_list, value_list, emb_config.emb_index);
+    gpu_kv->GetSnapshot(key_list, value_list, emb_config);
     return key_list->size();
   }

tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_backward_base_ops.cu.h

Lines changed: 141 additions & 128 deletions
@@ -27,6 +27,12 @@ namespace {
 
 template <typename TKey, typename TValue>
 struct GroupEmbeddingBackWardArgs {
+  GroupEmbeddingBackWardArgs() = default;
+  GroupEmbeddingBackWardArgs(TValue *grads, TKey *sp_values,
+                             TValue *emb_variable, TValue *grads_output,
+                             int *offset_indices, int nnz)
+      : grads_(grads), sp_values_(sp_values), emb_variable_(emb_variable),
+        grads_output_(grads_output), offset_indices_(offset_indices), nnz_(nnz) {}
   TValue *grads_;
   TKey *sp_values_;
   TValue *emb_variable_;
@@ -59,26 +65,29 @@ __global__ void ComputeEVGradFn(
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
 
-    float grad = args[idx].grads_[bid * dimension + tid];
-    grad = CombineGrad<combiner>(grad, feature_num);
-
-    for (int j = 0; j < feature_num; ++j) {
-      float grad_i = grad;
-      int feature_offset = (value_offset + j) * dimension;
-      if (max_norm > 0.0f) {
-        float emb_element = 0.0f;  // TODO: hujunqi get emb_weight
-        if (tid == 0) {
-          l2_sum = 0.0f;
-        }
-        tile.shfl(l2_sum, 0);
-        atomicAdd(&l2_sum, emb_element * emb_element);
-        tile.sync();
-        float l2_norm = sqrtf(l2_sum);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+
+      for (int j = 0; j < feature_num; ++j) {
+        float grad_i = grad;
+        int feature_offset = (value_offset + j) * dimension;
+        if (max_norm > 0.0f) {
+          float emb_element = 0.0f;  // TODO(junqihu): get emb_weight
+          if (tid == 0) {
+            l2_sum = 0.0f;
+          }
+          tile.shfl(l2_sum, 0);
+          atomicAdd(&l2_sum, emb_element * emb_element);
+          tile.sync();
+          float l2_norm = sqrtf(l2_sum);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + j) * dimension + tid] =
+            grad_i;
       }
-      args[idx].grads_output_[(value_offset + j) * dimension + tid] = grad_i;
     }
   }
 }
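The new `if (feature_num > 0)` guard matters for batch rows that contribute no features to a lookup. A hedged host-side sketch of the hazard it removes, assuming the Mean combiner divides the incoming gradient by `feature_num` (the conventional way `CombineGrad` is implemented):

```cpp
#include <cstdio>

// Toy stand-in for CombineGrad<Mean> (an assumption, not the real code):
// divides the gradient by the row's feature count.
float CombineMean(float grad, int feature_num) {
  return grad / feature_num;
}

int main() {
  printf("%f\n", CombineMean(1.0f, 3));  // populated row: 0.333333
  printf("%f\n", CombineMean(1.0f, 0));  // empty row: inf -- now skipped by
                                         // the feature_num > 0 guard
}
```

The per-feature loop writes nothing for an empty row anyway, so skipping the whole block also avoids the pointless gradient load and combine.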
@@ -105,31 +114,29 @@ __global__ void ComputeSparseGradFn(
     } else {
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
-    float grad = args[idx].grads_[bid * dimension + tid];
-    // printf("feature_num is %d , grad is %lld , bid is %d , tid is %d \n",
-    //        feature_num, grad, blockIdx.x, threadIdx.x);
-    grad = CombineGrad<combiner>(grad, feature_num);
-    for (int i = 0; i < feature_num; i++) {
-      float grad_i = grad;
-      if (max_norm > 0.0f) {
-        int64_t indices = int(args[idx].sp_values_[value_offset + i]);
-        float emb_element = args[idx].emb_variable_[indices * dimension + tid];
-        // if (FastBoundsCheck(indices, args[idx].emb_row_size_)) {
-        //   emb_element =
-        //       args[idx].emb_variable_[indices * dimension + tid];
-        // }
-        if (tid == 0) {
-          l2_sum = 0.0f;
-        }
-        tile.shfl(l2_sum, 0);
-        atomicAdd(&l2_sum, emb_element * emb_element);
-        tile.sync();
-        float l2_norm = sqrtf(l2_sum);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+      for (int i = 0; i < feature_num; i++) {
+        float grad_i = grad;
+        if (max_norm > 0.0f) {
+          int64_t indices = int(args[idx].sp_values_[value_offset + i]);
+          float emb_element =
+              args[idx].emb_variable_[indices * dimension + tid];
+          if (tid == 0) {
+            l2_sum = 0.0f;
+          }
+          tile.shfl(l2_sum, 0);
+          atomicAdd(&l2_sum, emb_element * emb_element);
+          tile.sync();
+          float l2_norm = sqrtf(l2_sum);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
       }
-      args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
     }
   }
 }
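For reference, the `max_norm` branch in these kernels is per-row L2 clipping of the gradient: the squared embedding elements are accumulated into `l2_sum` (the kernels do this with an `atomicAdd` across the tile or block), and if the resulting norm exceeds `max_norm` the gradient is scaled by `max_norm / l2_norm`. A small host-side rendition with made-up numbers:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float emb[4] = {3.0f, 0.0f, 4.0f, 0.0f};  // ||emb||_2 = 5
  const float max_norm = 2.5f;

  float l2_sum = 0.0f;                  // kernels build this with atomicAdd
  for (float e : emb) l2_sum += e * e;  // across the tile / block
  float l2_norm = std::sqrt(l2_sum);

  float grad_i = 1.0f;
  if (max_norm > 0.0f && l2_norm > max_norm) grad_i *= max_norm / l2_norm;
  printf("grad_i = %f\n", grad_i);      // 1.0 * 2.5 / 5 = 0.5
}
```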
@@ -156,26 +163,29 @@ __global__ void NormalComputeEVGradFn(
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
 
-    float grad = args[idx].grads_[bid * dimension + tid];
-    grad = CombineGrad<combiner>(grad, feature_num);
-
-    for (int j = 0; j < feature_num; ++j) {
-      float grad_i = grad;
-      int feature_offset = (value_offset + j) * dimension;
-      if (max_norm > 0.0f) {
-        float emb_element = 0.0f;  // TODO: hujunqi get emb_weight
-        if (tid == 0) {
-          l2_sum[0] = 0.0f;
-        }
-        __syncthreads();
-        atomicAdd(l2_sum, emb_element * emb_element);
-        __syncthreads();
-        float l2_norm = sqrtf(l2_sum[0]);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+
+      for (int j = 0; j < feature_num; ++j) {
+        float grad_i = grad;
+        int feature_offset = (value_offset + j) * dimension;
+        if (max_norm > 0.0f) {
+          float emb_element = 0.0f;  // TODO(junqihu): get emb_weight
+          if (tid == 0) {
+            l2_sum[0] = 0.0f;
+          }
+          __syncthreads();
+          atomicAdd(l2_sum, emb_element * emb_element);
+          __syncthreads();
+          float l2_norm = sqrtf(l2_sum[0]);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + j) * dimension + tid] =
+            grad_i;
       }
-      args[idx].grads_output_[(value_offset + j) * dimension + tid] = grad_i;
     }
   }
 }
@@ -201,27 +211,29 @@ __global__ void NormalComputeSparseGradFn(
     } else {
       feature_num = args[idx].offset_indices_[bid + 1] - value_offset;
     }
-    float grad = args[idx].grads_[bid * dimension + tid];
-    // printf("feature_num is %d , grad is %lld , bid is %d , tid is %d \n",
-    //        feature_num, grad, blockIdx.x, threadIdx.x);
-    grad = CombineGrad<combiner>(grad, feature_num);
-    for (int i = 0; i < feature_num; i++) {
-      float grad_i = grad;
-      if (max_norm > 0.0f) {
-        int64_t indices = int(args[idx].sp_values_[value_offset + i]);
-        float emb_element = args[idx].emb_variable_[indices * dimension + tid];
-        if (tid == 0) {
-          l2_sum[0] = 0.0f;
-        }
-        __syncthreads();
-        atomicAdd(l2_sum, emb_element * emb_element);
-        __syncthreads();
-        float l2_norm = sqrtf(l2_sum[0]);
-        if (l2_norm > max_norm) {
-          grad_i *= max_norm / l2_norm;
+
+    if (feature_num > 0) {
+      float grad = args[idx].grads_[bid * dimension + tid];
+      grad = CombineGrad<combiner>(grad, feature_num);
+      for (int i = 0; i < feature_num; i++) {
+        float grad_i = grad;
+        if (max_norm > 0.0f) {
+          int64_t indices = int(args[idx].sp_values_[value_offset + i]);
+          float emb_element =
+              args[idx].emb_variable_[indices * dimension + tid];
+          if (tid == 0) {
+            l2_sum[0] = 0.0f;
+          }
+          __syncthreads();
+          atomicAdd(l2_sum, emb_element * emb_element);
+          __syncthreads();
+          float l2_norm = sqrtf(l2_sum[0]);
+          if (l2_norm > max_norm) {
+            grad_i *= max_norm / l2_norm;
+          }
         }
+        args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
       }
-      args[idx].grads_output_[(value_offset + i) * dimension + tid] = grad_i;
     }
   }
 }
@@ -231,52 +243,54 @@ __global__ void NormalComputeSparseGradFn(
 template <typename TKey, typename TValue>
 class GroupEmbeddingLookupBackWard {
  public:
-  void initialize(int dimension, int num_lookups, float max_norm) {
-    CK_CUDA_THROW_(cudaMalloc(
-        &d_args_,
-        sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>) * num_lookups));
-    args_.resize(num_lookups);
+  explicit GroupEmbeddingLookupBackWard(int dimension, int num_lookups,
+                                        float max_norm,
+                                        Allocator *gpu_allocator = nullptr)
+      : alloc_(gpu_allocator) {
+    d_args_ =
+        TypedAllocator::Allocate<GroupEmbeddingBackWardArgs<TKey, TValue>>(
+            gpu_allocator, num_lookups, AllocationAttributes());
+    h_args_.reserve(num_lookups);
     max_norm_ = max_norm;
     nums_ = num_lookups;
     dimension_ = dimension;
   }
 
-  void set(int idx, TValue *grads, TValue *grads_output, int *offset_indices,
-           TKey *sp_values, TValue *emb_variable, int nnz) {
-    args_[idx].grads_ = grads;
-    args_[idx].grads_output_ = grads_output;
-    args_[idx].offset_indices_ = offset_indices;
-    args_[idx].sp_values_ = sp_values;
-    args_[idx].emb_variable_ = emb_variable;
-    args_[idx].nnz_ = nnz;
+  void set(GroupEmbeddingBackWardArgs<TKey, TValue> &arg) {
+    h_args_.emplace_back(arg);
   }
 
   ~GroupEmbeddingLookupBackWard() {
-    if (d_args_) {
-      CK_CUDA_THROW_(cudaFree(d_args_));
-    }
+    TypedAllocator::Deallocate(alloc_, d_args_, nums_);
   }
 
   template <typename GradFn>
-  void Backward(GradFn fn, int batch_size, int tile_size, cudaStream_t stream) {
+  inline void Backward(GradFn fn, int batch_size, int tile_size,
+                       cudaStream_t stream) {
     CK_CUDA_THROW_(cudaMemcpyAsync(
-        d_args_, args_.data(),
-        args_.size() * sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>),
+        d_args_, h_args_.data(),
+        h_args_.size() * sizeof(GroupEmbeddingBackWardArgs<TKey, TValue>),
         cudaMemcpyHostToDevice, stream));
 
     {
-      const int block_size = (batch_size - 1) / 64 * tile_size + 1;
+      if (tile_size <= 32) {
+        const int block_size = batch_size / 64 * tile_size + 1;
 
-      fn<<<block_size, 64, 0, stream>>>(batch_size, max_norm_, nums_,
-                                        dimension_, d_args_);
+        fn<<<block_size, 64, 0, stream>>>(batch_size, max_norm_, nums_,
+                                          dimension_, d_args_);
+      } else {
+        fn<<<batch_size, tile_size, 0, stream>>>(batch_size, max_norm_, nums_,
+                                                 dimension_, d_args_);
+      }
     }
 
     CK_CUDA_THROW_(cudaGetLastError());
   }
 
  protected:
-  std::vector<GroupEmbeddingBackWardArgs<TKey, TValue>> args_;
+  std::vector<GroupEmbeddingBackWardArgs<TKey, TValue>> h_args_;
   GroupEmbeddingBackWardArgs<TKey, TValue> *d_args_;
+  Allocator *alloc_;
   float max_norm_;
   int nums_;
   int dimension_;
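The reworked `Backward` is where the launch-parameter fix lands: for tile sizes up to 32 it keeps packing several cooperative-group tiles into a 64-thread block, while for larger dimensions it now launches one block per batch row with `tile_size` threads (the callers in the next hunk pass `dimension_` as the tile size in that case). A standalone sketch of the resulting launch shapes:

```cpp
#include <cstdio>

// Mirrors the branch added in Backward(): grid/block shape as a function of
// batch_size and tile_size.
void PrintLaunchShape(int batch_size, int tile_size) {
  if (tile_size <= 32) {
    // Several tiles share a 64-thread block.
    int block_size = batch_size / 64 * tile_size + 1;
    printf("<<<%d, 64>>>\n", block_size);
  } else {
    // One block per batch row, one thread per embedding column.
    printf("<<<%d, %d>>>\n", batch_size, tile_size);
  }
}

int main() {
  PrintLaunchShape(1024, 8);    // dimension <= 8   -> <<<129, 64>>>
  PrintLaunchShape(1024, 128);  // dimension == 128 -> <<<1024, 128>>>
}
```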
@@ -290,56 +304,55 @@ class GroupLookupBackWardBaseOp : public OpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("max_norm", &max_norm_));
     OP_REQUIRES_OK(c, c->GetAttr("num_lookups", &num_lookups_));
     OP_REQUIRES_OK(c, c->GetAttr("dimension", &dimension_));
-    lookuper_.initialize(dimension_, num_lookups_, max_norm_);
   }
 
   template <bool Isev = false, Combiner combiner>
-  inline void compute(const int batch_size, cudaStream_t stream) {
+  inline void compute(GroupEmbeddingLookupBackWard<TKey, TValue> &lookuper,
+                      const int batch_size, cudaStream_t stream) {
     if (Isev) {
       if (dimension_ <= 2) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 2>,
-                           batch_size, 2, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 2>,
+                          batch_size, 2, stream);
       } else if (dimension_ <= 4) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 4>,
-                           batch_size, 4, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 4>,
+                          batch_size, 4, stream);
       } else if (dimension_ <= 8) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 8>,
-                           batch_size, 8, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 8>,
+                          batch_size, 8, stream);
       } else if (dimension_ <= 16) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 16>,
-                           batch_size, 16, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 16>,
+                          batch_size, 16, stream);
       } else if (dimension_ <= 32) {
-        lookuper_.Backward(ComputeEVGradFn<TKey, TValue, combiner, 32>,
-                           batch_size, 32, stream);
+        lookuper.Backward(ComputeEVGradFn<TKey, TValue, combiner, 32>,
+                          batch_size, 32, stream);
       } else {
-        lookuper_.Backward(NormalComputeEVGradFn<TKey, TValue, combiner>,
-                           batch_size, 64, stream);
+        lookuper.Backward(NormalComputeEVGradFn<TKey, TValue, combiner>,
+                          batch_size, dimension_, stream);
       }
     } else {
       if (dimension_ <= 2) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 2>,
-                           batch_size, 2, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 2>,
+                          batch_size, 2, stream);
       } else if (dimension_ <= 4) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 4>,
-                           batch_size, 4, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 4>,
+                          batch_size, 4, stream);
       } else if (dimension_ <= 8) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 8>,
-                           batch_size, 8, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 8>,
+                          batch_size, 8, stream);
       } else if (dimension_ <= 16) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 16>,
-                           batch_size, 16, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 16>,
+                          batch_size, 16, stream);
       } else if (dimension_ <= 32) {
-        lookuper_.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 32>,
-                           batch_size, 32, stream);
+        lookuper.Backward(ComputeSparseGradFn<TKey, TValue, combiner, 32>,
+                          batch_size, 32, stream);
       } else {
-        lookuper_.Backward(NormalComputeSparseGradFn<TKey, TValue, combiner>,
-                           batch_size, 64, stream);
+        lookuper.Backward(NormalComputeSparseGradFn<TKey, TValue, combiner>,
+                          batch_size, dimension_, stream);
       }
     }
   }
 
  protected:
-  GroupEmbeddingLookupBackWard<TKey, TValue> lookuper_;
   std::string combiner_;
   float max_norm_;
   int num_lookups_;
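Since `lookuper_` is no longer a member, each derived op's Compute is now expected to build the lookuper itself and pass it to `compute()`. A hedged sketch of what that call site presumably looks like; the real op implementations are in the other files of this commit (not shown here), and the allocator retrieval plus the combiner value are assumptions:

```cpp
// Hypothetical caller-side flow; names mirror this diff, plumbing is assumed.
GroupEmbeddingLookupBackWard<TKey, TValue> lookuper(
    dimension_, num_lookups_, max_norm_,
    ctx->get_allocator(AllocatorAttributes()));  // GPU allocator assumed

for (int i = 0; i < num_lookups_; ++i) {
  GroupEmbeddingBackWardArgs<TKey, TValue> arg(
      grads[i], sp_values[i], emb_variable[i], grads_output[i],
      offset_indices[i], nnz[i]);
  lookuper.set(arg);
}

// Combiner value is illustrative; dispatches to Backward() as shown above.
this->template compute<false, Combiner::Mean>(lookuper, batch_size, stream);
```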
