Skip to content

Commit ce39947

Browse files
authored
[Refactor] Refactor the structure of MultiTierStorage class and its subclasses. (#856)
1. The MultiTierStorage class no longer accesses the data of the KVInterface class.
2. A subclass of MultiTierStorage now consists of multiple SingleTierStorage objects instead of KVInterface objects.
Signed-off-by: lixy9474 <[email protected]>
1 parent 5c8eb4f commit ce39947

File tree

11 files changed

+867
-763
lines changed

11 files changed

+867
-763
lines changed

tensorflow/core/framework/embedding/dram_leveldb_storage.h

Lines changed: 116 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ limitations under the License.
1616
#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_DRAM_LEVELDB_STORAGE_H_
1717

1818
#include "tensorflow/core/framework/embedding/leveldb_kv.h"
19+
#include "tensorflow/core/framework/embedding/cpu_hash_map_kv.h"
1920
#include "tensorflow/core/framework/embedding/multi_tier_storage.h"
21+
#include "tensorflow/core/framework/embedding/single_tier_storage.h"
2022

2123
namespace tensorflow {
2224
template <class V>
@@ -31,55 +33,23 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
3133
public:
3234
DramLevelDBStore(const StorageConfig& sc, Allocator* alloc,
3335
LayoutCreator<V>* lc, const std::string& name)
34-
: alloc_(alloc), layout_creator_(lc),
35-
MultiTierStorage<K, V>(sc, name) {
36-
dram_kv_ = new LocklessHashMap<K, V>();
37-
leveldb_ = new LevelDBKV<K, V>(sc.path);
38-
if (sc.embedding_config.steps_to_live != 0) {
39-
dram_policy_ = new GlobalStepShrinkPolicy<K, V>(dram_kv_, alloc_,
40-
sc.embedding_config.slot_num + 1);
41-
leveldb_policy_ = new GlobalStepShrinkPolicy<K, V>(leveldb_, alloc_,
42-
sc.embedding_config.slot_num + 1);
43-
} else if (sc.embedding_config.l2_weight_threshold != -1.0) {
44-
dram_policy_ =
45-
new L2WeightShrinkPolicy<K, V>(
46-
sc.embedding_config.l2_weight_threshold,
47-
sc.embedding_config.primary_emb_index,
48-
Storage<K, V>::GetOffset(sc.embedding_config.primary_emb_index),
49-
dram_kv_, alloc_,
50-
sc.embedding_config.slot_num + 1);
51-
leveldb_policy_ =
52-
new L2WeightShrinkPolicy<K, V>(
53-
sc.embedding_config.l2_weight_threshold,
54-
sc.embedding_config.primary_emb_index,
55-
Storage<K, V>::GetOffset(sc.embedding_config.primary_emb_index),
56-
leveldb_, alloc_,
57-
sc.embedding_config.slot_num + 1);
58-
} else {
59-
dram_policy_ = nullptr;
60-
leveldb_policy_ = nullptr;
61-
}
62-
63-
MultiTierStorage<K, V>::kvs_.emplace_back(
64-
KVInterfaceDescriptor<K, V>(dram_kv_, alloc_, dram_mu_, dram_policy_));
65-
MultiTierStorage<K, V>::kvs_.emplace_back(
66-
KVInterfaceDescriptor<K, V>(leveldb_, alloc_,
67-
leveldb_mu_, leveldb_policy_));
36+
: MultiTierStorage<K, V>(sc, name) {
37+
dram_ = new DramStorage<K, V>(sc, alloc, lc, new LocklessHashMap<K, V>());
38+
leveldb_ = new LevelDBStore<K, V>(sc, alloc, lc);
6839
}
6940

7041
~DramLevelDBStore() override {
71-
MultiTierStorage<K, V>::ReleaseValues(
72-
{std::make_pair(dram_kv_, alloc_)});
73-
if (dram_policy_ != nullptr) delete dram_policy_;
74-
if (leveldb_policy_ != nullptr) delete leveldb_policy_;
42+
MultiTierStorage<K, V>::DeleteFromEvictionManager();
43+
delete dram_;
44+
delete leveldb_;
7545
}
7646

7747
TF_DISALLOW_COPY_AND_ASSIGN(DramLevelDBStore);
7848

7949
Status Get(K key, ValuePtr<V>** value_ptr) override {
80-
Status s = dram_kv_->Lookup(key, value_ptr);
50+
Status s = dram_->Get(key, value_ptr);
8151
if (!s.ok()) {
82-
s = leveldb_->Lookup(key, value_ptr);
52+
s = leveldb_->Get(key, value_ptr);
8353
}
8454
return s;
8555
}
@@ -89,48 +59,38 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
8959
}
9060

9161
void Insert(K key, ValuePtr<V>** value_ptr,
92-
int64 alloc_len) override {
93-
do {
94-
*value_ptr = layout_creator_->Create(alloc_, alloc_len);
95-
Status s = dram_kv_->Insert(key, *value_ptr);
96-
if (s.ok()) {
97-
break;
98-
} else {
99-
(*value_ptr)->Destroy(alloc_);
100-
delete *value_ptr;
101-
}
102-
} while (!(dram_kv_->Lookup(key, value_ptr)).ok());
62+
size_t alloc_len) override {
63+
dram_->Insert(key, value_ptr, alloc_len);
10364
}
10465

10566
Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
10667
size_t size, CopyBackFlag &need_copyback) override {
107-
need_copyback = NOT_COPYBACK;
108-
return GetOrCreate(key, value_ptr, size);
68+
LOG(FATAL)<<"GetOrCreate(K key, ValuePtr<V>** value_ptr, "
69+
<<"size_t size, CopyBackFlag &need_copyback) "
70+
<<"in DramLevelDBStore can not be called.";
10971
}
11072

11173
Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
11274
size_t size) override {
113-
Status s = dram_kv_->Lookup(key, value_ptr);
75+
Status s = dram_->Get(key, value_ptr);
11476
if (s.ok()) {
11577
return s;
11678
}
117-
s = leveldb_->Lookup(key, value_ptr);
118-
if (!s.ok()) {
119-
*value_ptr = layout_creator_->Create(alloc_, size);
120-
}
121-
122-
s = dram_kv_->Insert(key, *value_ptr);
79+
s = leveldb_->Get(key, value_ptr);
12380
if (s.ok()) {
124-
return s;
81+
s = dram_->TryInsert(key, *value_ptr);
82+
if (s.ok()) {
83+
return s;
84+
}
85+
leveldb_->DestroyValuePtr(*value_ptr);
86+
return dram_->Get(key, value_ptr);
12587
}
126-
// Insert Failed, key already exist
127-
(*value_ptr)->Destroy(alloc_);
128-
delete *value_ptr;
129-
return dram_kv_->Lookup(key, value_ptr);
88+
dram_->Insert(key, value_ptr, size);
89+
return Status::OK();
13090
}
13191

13292
Status Remove(K key) override {
133-
dram_kv_->Remove(key);
93+
dram_->Remove(key);
13494
leveldb_->Remove(key);
13595
return Status::OK();
13696
}
@@ -150,46 +110,122 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
150110
}
151111

152112
void iterator_mutex_lock() override {
153-
leveldb_mu_.lock();
113+
leveldb_->get_mutex()->lock();
154114
}
155115

156116
void iterator_mutex_unlock() override {
157-
leveldb_mu_.unlock();
117+
leveldb_->get_mutex()->unlock();
158118
}
159119

160120
int64 Size() const override {
161-
int64 total_size = dram_kv_->Size();
121+
int64 total_size = dram_->Size();
162122
total_size += leveldb_->Size();
163123
return total_size;
164124
}
165125

126+
int64 Size(int level) const override {
127+
if (level == 0) {
128+
return dram_->Size();
129+
} else if (level == 1) {
130+
return leveldb_->Size();
131+
} else {
132+
return -1;
133+
}
134+
}
135+
136+
int LookupTier(K key) const override {
137+
Status s = dram_->Contains(key);
138+
if (s.ok())
139+
return 0;
140+
s = leveldb_->Contains(key);
141+
if (s.ok())
142+
return 1;
143+
return -1;
144+
}
145+
166146
Status GetSnapshot(std::vector<K>* key_list,
167147
std::vector<ValuePtr<V>*>* value_ptr_list) override {
168148
{
169-
mutex_lock l(dram_mu_);
170-
TF_CHECK_OK(dram_kv_->GetSnapshot(key_list, value_ptr_list));
149+
mutex_lock l(*(dram_->get_mutex()));
150+
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
171151
}
172152
{
173-
mutex_lock l(leveldb_mu_);
153+
mutex_lock l(*(leveldb_->get_mutex()));
174154
TF_CHECK_OK(leveldb_->GetSnapshot(key_list, value_ptr_list));
175155
}
176156
return Status::OK();
177157
}
178158

159+
Status Shrink(int64 value_len) override {
160+
dram_->Shrink(value_len);
161+
leveldb_->Shrink(value_len);
162+
return Status::OK();
163+
}
164+
165+
Status Shrink(int64 global_step, int64 steps_to_live) override {
166+
dram_->Shrink(global_step, steps_to_live);
167+
leveldb_->Shrink(global_step, steps_to_live);
168+
return Status::OK();
169+
}
170+
171+
int64 GetSnapshot(std::vector<K>* key_list,
172+
std::vector<V* >* value_list,
173+
std::vector<int64>* version_list,
174+
std::vector<int64>* freq_list,
175+
const EmbeddingConfig& emb_config,
176+
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
177+
embedding::Iterator** it) override {
178+
{
179+
mutex_lock l(*(dram_->get_mutex()));
180+
std::vector<ValuePtr<V>*> value_ptr_list;
181+
std::vector<K> key_list_tmp;
182+
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
183+
MultiTierStorage<K, V>::SetListsForCheckpoint(
184+
key_list_tmp, value_ptr_list, emb_config,
185+
key_list, value_list, version_list, freq_list);
186+
}
187+
{
188+
mutex_lock l(*(leveldb_->get_mutex()));
189+
*it = leveldb_->GetIterator();
190+
}
191+
return key_list->size();
192+
}
193+
194+
Status Eviction(K* evict_ids, int64 evict_size) override {
195+
ValuePtr<V>* value_ptr;
196+
for (int64 i = 0; i < evict_size; ++i) {
197+
if (dram_->Get(evict_ids[i], &value_ptr).ok()) {
198+
TF_CHECK_OK(leveldb_->Commit(evict_ids[i], value_ptr));
199+
TF_CHECK_OK(dram_->Remove(evict_ids[i]));
200+
dram_->DestroyValuePtr(value_ptr);
201+
}
202+
}
203+
return Status::OK();
204+
}
205+
206+
Status EvictionWithDelayedDestroy(K* evict_ids, int64 evict_size) override {
207+
mutex_lock l(*(dram_->get_mutex()));
208+
mutex_lock l1(*(leveldb_->get_mutex()));
209+
MultiTierStorage<K, V>::ReleaseInvalidValuePtr(dram_->alloc_);
210+
ValuePtr<V>* value_ptr = nullptr;
211+
for (int64 i = 0; i < evict_size; ++i) {
212+
if (dram_->Get(evict_ids[i], &value_ptr).ok()) {
213+
TF_CHECK_OK(leveldb_->Commit(evict_ids[i], value_ptr));
214+
TF_CHECK_OK(dram_->Remove(evict_ids[i]));
215+
MultiTierStorage<K, V>::KeepInvalidValuePtr(value_ptr);
216+
}
217+
}
218+
return Status::OK();
219+
}
220+
179221
protected:
180222
void SetTotalDims(int64 total_dims) override {
181223
leveldb_->SetTotalDims(total_dims);
182224
}
183225

184226
private:
185-
KVInterface<K, V>* dram_kv_;
186-
KVInterface<K, V>* leveldb_;
187-
Allocator* alloc_;
188-
ShrinkPolicy<K, V>* dram_policy_;
189-
ShrinkPolicy<K, V>* leveldb_policy_;
190-
LayoutCreator<V>* layout_creator_;
191-
mutex dram_mu_; //must be locked before leveldb_mu_ is locked
192-
mutex leveldb_mu_;
227+
DramStorage<K, V>* dram_;
228+
LevelDBStore<K, V>* leveldb_;
193229
};
194230
} // embedding
195231
} // tensorflow

0 commit comments

Comments
 (0)