Skip to content

Commit 56cc51e

Browse files
authored
[Embedding] Refactor the restore interface of EmbeddingVariable. (#903)
Support restoring parameters from a single or partitioned EmbeddingVariable. Signed-off-by: JunqiHu <[email protected]>
1 parent 96d66ab commit 56cc51e

21 files changed

+1754
-1712
lines changed

tensorflow/core/framework/embedding/bloom_filter_policy.h

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ const static std::vector<int64> default_seeds = {
3131

3232
template<typename K, typename V, typename EV>
3333
class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
34+
using FilterPolicy<K, V, EV>::ev_;
35+
using FilterPolicy<K, V, EV>::config_;
36+
3437
public:
35-
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev)
36-
: config_(config), ev_(ev) {
38+
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev) :
39+
FilterPolicy<K, V, EV>(config, ev) {
40+
3741
switch (config_.counter_type){
3842
case DT_UINT64:
3943
VLOG(2) << "The type of bloom counter is uint64";
@@ -303,16 +307,18 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
303307
}
304308
}
305309

306-
Status Import(RestoreBuffer& restore_buff,
307-
int64 key_num,
308-
int bucket_num,
309-
int64 partition_id,
310-
int64 partition_num,
311-
bool is_filter) override {
310+
Status Restore(int64 key_num, int bucket_num, int64 partition_id,
311+
int64 partition_num, int64 value_len, bool is_filter,
312+
bool to_dram, bool is_incr, RestoreBuffer& restore_buff) override {
312313
K* key_buff = (K*)restore_buff.key_buffer;
313314
V* value_buff = (V*)restore_buff.value_buffer;
314315
int64* version_buff = (int64*)restore_buff.version_buffer;
315316
int64* freq_buff = (int64*)restore_buff.freq_buffer;
317+
if (to_dram) {
318+
LOG(FATAL)<<"BloomFilter dosen't support ImportToDRAM";
319+
return Status::OK();
320+
}
321+
316322
for (auto i = 0; i < key_num; ++i) {
317323
// This filtering could be expressed in the graph (Mod + DynamicPartition),
318324
// but that would waste memory and be slow.
@@ -333,33 +339,19 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
333339
SetBloomFreq(key_buff[i], freq_buff[i]);
334340
}
335341
if (new_freq >= config_.filter_freq){
336-
ev_->CreateKey(key_buff[i], &value_ptr);
342+
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
337343
if (config_.steps_to_live != 0 || config_.record_version) {
338344
value_ptr->SetStep(version_buff[i]);
339345
}
340346
if (!is_filter){
341347
ev_->LookupOrCreateEmb(value_ptr,
342-
value_buff + i * ev_->ValueLen());
348+
value_buff + i * ev_->ValueLen());
343349
} else {
344350
ev_->LookupOrCreateEmb(value_ptr,
345-
ev_->GetDefaultValue(key_buff[i]));
351+
ev_->GetDefaultValue(key_buff[i]));
346352
}
347353
}
348354
}
349-
if (ev_->IsMultiLevel() && !ev_->IsUseHbm() && config_.is_primary()) {
350-
ev_->UpdateCache(key_buff, key_num, version_buff, freq_buff);
351-
}
352-
return Status::OK();
353-
}
354-
355-
Status ImportToDram(RestoreBuffer& restore_buff,
356-
int64 key_num,
357-
int bucket_num,
358-
int64 partition_id,
359-
int64 partition_num,
360-
bool is_filter,
361-
V* default_values) override {
362-
LOG(FATAL)<<"BloomFilter dosen't support ImportToDRAM";
363355
return Status::OK();
364356
}
365357

@@ -455,11 +447,8 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
455447
}
456448
}
457449
}
458-
459450
private:
460451
void* bloom_counter_;
461-
EmbeddingConfig config_;
462-
EV* ev_;
463452
std::vector<int64> seeds_;
464453
};
465454
} // tensorflow

tensorflow/core/framework/embedding/counter_filter_policy.h

Lines changed: 14 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@ namespace tensorflow {
2323

2424
template<typename K, typename V, typename EV>
2525
class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
26+
using FilterPolicy<K, V, EV>::ev_;
27+
using FilterPolicy<K, V, EV>::config_;
28+
using FilterPolicy<K, V, EV>::LookupOrCreateEmbInternal;
29+
2630
public:
27-
CounterFilterPolicy(const EmbeddingConfig& config, EV* ev)
28-
: config_(config), ev_(ev){
29-
}
31+
CounterFilterPolicy(const EmbeddingConfig& config, EV* ev) :
32+
FilterPolicy<K, V, EV>(config, ev) {}
3033

3134
Status Lookup(K key, V* val, const V* default_value_ptr,
3235
const V* default_value_no_permission) override {
@@ -115,60 +118,13 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
115118
return value_ptr->GetFreq();
116119
}
117120

118-
Status Import(RestoreBuffer& restore_buff,
119-
int64 key_num,
120-
int bucket_num,
121-
int64 partition_id,
122-
int64 partition_num,
123-
bool is_filter) override {
124-
K* key_buff = (K*)restore_buff.key_buffer;
125-
V* value_buff = (V*)restore_buff.value_buffer;
126-
int64* version_buff = (int64*)restore_buff.version_buffer;
127-
int64* freq_buff = (int64*)restore_buff.freq_buffer;
128-
for (auto i = 0; i < key_num; ++i) {
129-
// this can describe by graph(Mod + DynamicPartition),
130-
// but memory waste and slow
131-
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
132-
LOG(INFO) << "skip EV key:" << *(key_buff + i);
133-
continue;
134-
}
135-
ValuePtr<V>* value_ptr = nullptr;
136-
ev_->CreateKey(key_buff[i], &value_ptr);
137-
if (!is_filter) {
138-
if (freq_buff[i] >= config_.filter_freq) {
139-
value_ptr->SetFreq(freq_buff[i]);
140-
} else {
141-
value_ptr->SetFreq(config_.filter_freq);
142-
}
143-
} else {
144-
value_ptr->SetFreq(freq_buff[i]);
145-
}
146-
if (config_.steps_to_live != 0 || config_.record_version) {
147-
value_ptr->SetStep(version_buff[i]);
148-
}
149-
if (value_ptr->GetFreq() >= config_.filter_freq) {
150-
if (!is_filter) {
151-
ev_->LookupOrCreateEmb(value_ptr,
152-
value_buff + i * ev_->ValueLen());
153-
} else {
154-
ev_->LookupOrCreateEmb(value_ptr,
155-
ev_->GetDefaultValue(key_buff[i]));
156-
}
157-
}
158-
}
159-
if (ev_->IsMultiLevel() && !ev_->IsUseHbm() && config_.is_primary()) {
160-
ev_->UpdateCache(key_buff, key_num, version_buff, freq_buff);
161-
}
162-
return Status::OK();
121+
bool is_admit(K key, ValuePtr<V>* value_ptr) override {
122+
return (GetFreq(key, value_ptr) >= config_.filter_freq);
163123
}
164124

165-
Status ImportToDram(RestoreBuffer& restore_buff,
166-
int64 key_num,
167-
int bucket_num,
168-
int64 partition_id,
169-
int64 partition_num,
170-
bool is_filter,
171-
V* default_values) override {
125+
Status Restore(int64 key_num, int bucket_num, int64 partition_id,
126+
int64 partition_num, int64 value_len, bool is_filter,
127+
bool to_dram, bool is_incr, RestoreBuffer& restore_buff) override {
172128
K* key_buff = (K*)restore_buff.key_buffer;
173129
V* value_buff = (V*)restore_buff.value_buffer;
174130
int64* version_buff = (int64*)restore_buff.version_buffer;
@@ -181,7 +137,7 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
181137
continue;
182138
}
183139
ValuePtr<V>* value_ptr = nullptr;
184-
ev_->CreateKeyOnDram(key_buff[i], &value_ptr);
140+
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
185141
if (!is_filter) {
186142
if (freq_buff[i] >= config_.filter_freq) {
187143
value_ptr->SetFreq(freq_buff[i]);
@@ -195,28 +151,12 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
195151
value_ptr->SetStep(version_buff[i]);
196152
}
197153
if (value_ptr->GetFreq() >= config_.filter_freq) {
198-
if (!is_filter) {
199-
ev_->LookupOrCreateEmb(value_ptr,
200-
value_buff + i * ev_->ValueLen(), ev_allocator());
201-
} else {
202-
ev_->LookupOrCreateEmb(value_ptr,
203-
default_values +
204-
(key_buff[i] % config_.default_value_dim)
205-
* ev_->ValueLen(),
206-
ev_allocator());
207-
}
154+
LookupOrCreateEmbInternal(is_filter, to_dram, i, value_len,
155+
value_ptr, value_buff, key_buff);
208156
}
209157
}
210158
return Status::OK();
211159
}
212-
213-
bool is_admit(K key, ValuePtr<V>* value_ptr) override {
214-
return (GetFreq(key, value_ptr) >= config_.filter_freq);
215-
}
216-
217-
private:
218-
EmbeddingConfig config_;
219-
EV* ev_;
220160
};
221161

222162
} // tensorflow

tensorflow/core/framework/embedding/dram_leveldb_storage.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
6868
}
6969

7070
void Insert(K key, ValuePtr<V>** value_ptr,
71-
size_t alloc_len) override {
71+
size_t alloc_len, bool to_dram = false) override {
7272
dram_->Insert(key, value_ptr, alloc_len);
7373
}
74-
7574
Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
7675
size_t size, CopyBackFlag &need_copyback) override {
7776
LOG(FATAL)<<"GetOrCreate(K key, ValuePtr<V>** value_ptr, "
@@ -112,12 +111,6 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
112111
return false;
113112
}
114113

115-
bool IsUsePersistentStorage() override {
116-
/*The return value is set to false temporarily,
117-
because the corresponding interface is not implemented.*/
118-
return false;
119-
}
120-
121114
void iterator_mutex_lock() override {
122115
leveldb_->get_mutex()->lock();
123116
}

tensorflow/core/framework/embedding/dram_pmem_storage.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,9 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
7676
}
7777

7878
void Insert(K key, ValuePtr<V>** value_ptr,
79-
size_t alloc_len) override {
79+
size_t alloc_len, bool to_dram = false) override {
8080
dram_->Insert(key, value_ptr, alloc_len);
8181
}
82-
8382
Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
8483
size_t size, CopyBackFlag &need_copyback) override {
8584
LOG(FATAL)<<"GetOrCreate(K key, ValuePtr<V>** value_ptr, "
@@ -95,12 +94,6 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
9594
return false;
9695
}
9796

98-
bool IsUsePersistentStorage() override {
99-
/*The return value is set to false temporarily,
100-
because the corresponding interface is not implemented.*/
101-
return false;
102-
}
103-
10497
Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
10598
size_t size) override {
10699
Status s = dram_->Get(key, value_ptr);

tensorflow/core/framework/embedding/dram_ssd_storage.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
6969
}
7070

7171
void Insert(K key, ValuePtr<V>** value_ptr,
72-
size_t alloc_len) override {
72+
size_t alloc_len, bool to_dram = false) override {
7373
dram_->Insert(key, value_ptr, alloc_len);
7474
}
7575

@@ -210,27 +210,27 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
210210
return key_list->size() + ssd_rec_desc->key_list.size();
211211
}
212212

213-
void RestoreSsdHashmap(
214-
K* key_list, int64* key_file_id_list,
215-
int64* key_offset_list, int64 num_of_keys,
216-
int64* file_list, int64* invalid_record_count_list,
217-
int64* record_count_list, int64 num_of_files,
218-
const std::string& ssd_emb_file_name) override {
213+
Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len,
214+
const std::string& ssd_emb_file_name, EmbeddingVar<K, V>* ev,
215+
RestoreSSDBuffer<K>& restore_buff) override {
216+
int64 alloc_len = Storage<K, V>::ComputeAllocLen(value_len);
219217
std::map<int64, int64> file_id_map;
220-
for (int64 i = 0; i < num_of_files; i++) {
221-
file_id_map[file_list[i]] = i;
218+
for (int64 i = 0; i < restore_buff.num_of_files; i++) {
219+
file_id_map[restore_buff.file_list_buf[i]] = i;
222220
}
223221

224-
ssd_hash_->CopyEmbFilesFromCkpt(
225-
file_list, invalid_record_count_list,
226-
record_count_list, num_of_files,
227-
ssd_emb_file_name);
228-
229-
ssd_hash_->Import(key_list, key_file_id_list,
230-
key_offset_list, num_of_keys,
231-
file_id_map);
222+
ssd_hash_->CopyEmbFilesFromCkpt(restore_buff.file_list_buf,
223+
restore_buff.invalid_record_count_list_buf,
224+
restore_buff.record_count_list_buf,
225+
restore_buff.num_of_files,
226+
ssd_emb_file_name);
227+
228+
ssd_hash_->Import(restore_buff.key_list_buf,
229+
restore_buff.key_file_id_list_buf,
230+
restore_buff.key_offset_list_buf,
231+
restore_buff.num_of_keys,
232+
file_id_map);
232233
}
233-
234234
Status Eviction(K* evict_ids, int64 evict_size) override {
235235
ValuePtr<V>* value_ptr = nullptr;
236236
for (int64 i = 0; i < evict_size; ++i) {

0 commit comments

Comments
 (0)