Skip to content

Commit 06f81cc

Browse files
authored
[Embedding] Refactor the data structure of EmbeddingVariable. (#924)
Signed-off-by: lixy9474 <[email protected]>
1 parent 29ecde4 commit 06f81cc

File tree

62 files changed

+3060
-3738
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+3060
-3738
lines changed

tensorflow/core/framework/embedding/bloom_filter_policy.h

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
3535
using FilterPolicy<K, V, EV>::config_;
3636

3737
public:
38-
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev) :
39-
FilterPolicy<K, V, EV>(config, ev) {
40-
38+
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev,
39+
embedding::FeatureDescriptor<V>* feat_desc)
40+
: feat_desc_(feat_desc),
41+
FilterPolicy<K, V, EV>(config, ev) {
4142
switch (config_.counter_type){
4243
case DT_UINT64:
4344
VLOG(2) << "The type of bloom counter is uint64";
@@ -64,10 +65,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
6465

6566
Status Lookup(K key, V* val, const V* default_value_ptr,
6667
const V* default_value_no_permission) override {
67-
ValuePtr<V>* value_ptr = nullptr;
68+
void* value_ptr = nullptr;
6869
Status s = ev_->LookupKey(key, &value_ptr);
6970
if (s.ok()) {
70-
V* mem_val = ev_->LookupOrCreateEmb(value_ptr, default_value_ptr);
71+
V* mem_val = feat_desc_->GetEmbedding(value_ptr, config_.emb_index);
7172
memcpy(val, mem_val, sizeof(V) * ev_->ValueLen());
7273
} else {
7374
memcpy(val, default_value_no_permission, sizeof(V) * ev_->ValueLen());
@@ -81,17 +82,17 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
8182
int64 num_of_keys,
8283
V* default_value_ptr,
8384
V* default_value_no_permission) override {
84-
std::vector<ValuePtr<V>*> value_ptr_list(num_of_keys, nullptr);
85+
std::vector<void*> value_ptr_list(num_of_keys, nullptr);
8586
ev_->BatchLookupKey(ctx, keys, value_ptr_list.data(), num_of_keys);
8687
std::vector<V*> embedding_ptr(num_of_keys, nullptr);
8788
auto do_work = [this, value_ptr_list, &embedding_ptr,
8889
default_value_ptr, default_value_no_permission]
8990
(int64 start, int64 limit) {
9091
for (int i = start; i < limit; i++) {
91-
ValuePtr<V>* value_ptr = value_ptr_list[i];
92+
void* value_ptr = value_ptr_list[i];
9293
if (value_ptr != nullptr) {
9394
embedding_ptr[i] =
94-
ev_->LookupOrCreateEmb(value_ptr, default_value_ptr);
95+
feat_desc_->GetEmbedding(value_ptr, config_.emb_index);
9596
} else {
9697
embedding_ptr[i] = default_value_no_permission;
9798
}
@@ -109,13 +110,13 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
109110
}
110111

111112
void BatchLookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& ctx,
112-
const K* keys, ValuePtr<V>** value_ptrs_list,
113+
const K* keys, void** value_ptrs_list,
113114
int64 num_of_keys) {
114115
int num_worker_threads = ctx.worker_threads->num_threads;
115116
std::vector<std::vector<K>> lookup_or_create_ids(num_worker_threads);
116117
std::vector<std::vector<int>>
117118
lookup_or_create_cursor(num_worker_threads);
118-
std::vector<std::vector<ValuePtr<V>*>>
119+
std::vector<std::vector<void*>>
119120
lookup_or_create_ptrs(num_worker_threads);
120121
IntraThreadCopyIdAllocator thread_copy_id_alloc(num_worker_threads);
121122
std::vector<std::list<int64>>
@@ -147,7 +148,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
147148
1000, do_work);
148149

149150
std::vector<K> total_ids(num_of_keys);
150-
std::vector<ValuePtr<V>*> total_ptrs(num_of_keys);
151+
std::vector<void*> total_ptrs(num_of_keys);
151152
std::vector<int> total_cursors(num_of_keys);
152153
int num_of_admit_id = 0;
153154
for (int i = 0; i < num_worker_threads; i++) {
@@ -157,7 +158,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
157158
sizeof(K) * lookup_or_create_ids[i].size());
158159
memcpy(total_ptrs.data() + num_of_admit_id,
159160
lookup_or_create_ptrs[i].data(),
160-
sizeof(ValuePtr<V>*) * lookup_or_create_ptrs[i].size());
161+
sizeof(void*) * lookup_or_create_ptrs[i].size());
161162
memcpy(total_cursors.data() + num_of_admit_id,
162163
lookup_or_create_cursor[i].data(),
163164
sizeof(int) * lookup_or_create_cursor[i].size());
@@ -174,31 +175,40 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
174175
#endif //GOOGLE_CUDA
175176

176177
void LookupOrCreate(K key, V* val, const V* default_value_ptr,
177-
ValuePtr<V>** value_ptr, int count,
178+
void** value_ptr, int count,
178179
const V* default_value_no_permission) override {
179180
if (GetBloomFreq(key) >= config_.filter_freq) {
180-
TF_CHECK_OK(ev_->LookupOrCreateKey(key, value_ptr));
181-
V* mem_val = ev_->LookupOrCreateEmb(*value_ptr, default_value_ptr);
181+
bool is_filter = true;
182+
TF_CHECK_OK(LookupOrCreateKey(key, value_ptr, &is_filter, count));
183+
V* mem_val = feat_desc_->GetEmbedding(*value_ptr, config_.emb_index);
182184
memcpy(val, mem_val, sizeof(V) * ev_->ValueLen());
183185
} else {
184186
AddFreq(key, count);
185187
memcpy(val, default_value_no_permission, sizeof(V) * ev_->ValueLen());
186188
}
187189
}
188190

189-
Status LookupOrCreateKey(K key, ValuePtr<V>** val,
191+
Status LookupOrCreateKey(K key, void** value_ptr,
190192
bool* is_filter, int64 count) override {
191-
*val = nullptr;
192-
if ((GetFreq(key, *val) + count) >= config_.filter_freq) {
193+
*value_ptr = nullptr;
194+
if ((GetFreq(key, *value_ptr) + count) >= config_.filter_freq) {
195+
Status s = ev_->LookupKey(key, value_ptr);
196+
if (!s.ok()) {
197+
*value_ptr = feat_desc_->Allocate();
198+
feat_desc_->SetDefaultValue(*value_ptr, key);
199+
ev_->storage()->Insert(key, value_ptr);
200+
s = Status::OK();
201+
}
193202
*is_filter = true;
194-
return ev_->LookupOrCreateKey(key, val);
203+
feat_desc_->AddFreq(*value_ptr, count);
204+
} else {
205+
*is_filter = false;
206+
AddFreq(key, count);
195207
}
196-
*is_filter = false;
197-
AddFreq(key, count);
198208
return Status::OK();
199209
}
200210

201-
int64 GetFreq(K key, ValuePtr<V>*) override {
211+
int64 GetFreq(K key, void* val) override {
202212
return GetBloomFreq(key);
203213
}
204214

@@ -210,7 +220,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
210220
return bloom_counter_;
211221
}
212222

213-
bool is_admit(K key, ValuePtr<V>* value_ptr) override {
223+
bool is_admit(K key, void* value_ptr) override {
214224
if (value_ptr == nullptr) {
215225
return false;
216226
} else {
@@ -326,8 +336,12 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
326336
LOG(INFO) << "skip EV key:" << *(key_buff + i);
327337
continue;
328338
}
329-
ValuePtr<V>* value_ptr = nullptr;
339+
void* value_ptr = nullptr;
330340
int64 new_freq = freq_buff[i];
341+
int64 import_version = -1;
342+
if (config_.steps_to_live != 0 || config_.record_version) {
343+
import_version = version_buff[i];
344+
}
331345
if (!is_filter) {
332346
if (freq_buff[i] >= config_.filter_freq) {
333347
SetBloomFreq(key_buff[i], freq_buff[i]);
@@ -339,17 +353,9 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
339353
SetBloomFreq(key_buff[i], freq_buff[i]);
340354
}
341355
if (new_freq >= config_.filter_freq){
342-
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
343-
if (config_.steps_to_live != 0 || config_.record_version) {
344-
value_ptr->SetStep(version_buff[i]);
345-
}
346-
if (!is_filter){
347-
ev_->LookupOrCreateEmb(value_ptr,
348-
value_buff + i * ev_->ValueLen());
349-
} else {
350-
ev_->LookupOrCreateEmb(value_ptr,
351-
ev_->GetDefaultValue(key_buff[i]));
352-
}
356+
ev_->storage()->Import(key_buff[i],
357+
value_buff + i * ev_->ValueLen(),
358+
new_freq, import_version, config_.emb_index);
353359
}
354360
}
355361
return Status::OK();
@@ -449,6 +455,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
449455
}
450456
private:
451457
void* bloom_counter_;
458+
embedding::FeatureDescriptor<V>* feat_desc_;
452459
std::vector<int64> seeds_;
453460
};
454461
} // tensorflow

tensorflow/core/framework/embedding/config.proto

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,7 @@ enum EmbeddingVariableType {
5050
enum ValuePtrStatus {
5151
OK = 0;
5252
IS_DELETED = 1;
53-
}
54-
55-
enum ValuePosition {
56-
IN_DRAM = 0;
57-
NOT_IN_DRAM = 1;
53+
NOT_IN_DRAM = 2;
5854
}
5955

6056
enum IsSetInitialized {

0 commit comments

Comments
 (0)