@@ -16,7 +16,9 @@ limitations under the License.
1616#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_DRAM_LEVELDB_STORAGE_H_
1717
1818#include " tensorflow/core/framework/embedding/leveldb_kv.h"
19+ #include " tensorflow/core/framework/embedding/cpu_hash_map_kv.h"
1920#include " tensorflow/core/framework/embedding/multi_tier_storage.h"
21+ #include " tensorflow/core/framework/embedding/single_tier_storage.h"
2022
2123namespace tensorflow {
2224template <class V >
@@ -31,55 +33,23 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
3133 public:
3234 DramLevelDBStore (const StorageConfig& sc, Allocator* alloc,
3335 LayoutCreator<V>* lc, const std::string& name)
34- : alloc_(alloc), layout_creator_(lc),
35- MultiTierStorage<K, V>(sc, name) {
36- dram_kv_ = new LocklessHashMap<K, V>();
37- leveldb_ = new LevelDBKV<K, V>(sc.path );
38- if (sc.embedding_config .steps_to_live != 0 ) {
39- dram_policy_ = new GlobalStepShrinkPolicy<K, V>(dram_kv_, alloc_,
40- sc.embedding_config .slot_num + 1 );
41- leveldb_policy_ = new GlobalStepShrinkPolicy<K, V>(leveldb_, alloc_,
42- sc.embedding_config .slot_num + 1 );
43- } else if (sc.embedding_config .l2_weight_threshold != -1.0 ) {
44- dram_policy_ =
45- new L2WeightShrinkPolicy<K, V>(
46- sc.embedding_config .l2_weight_threshold ,
47- sc.embedding_config .primary_emb_index ,
48- Storage<K, V>::GetOffset (sc.embedding_config .primary_emb_index ),
49- dram_kv_, alloc_,
50- sc.embedding_config .slot_num + 1 );
51- leveldb_policy_ =
52- new L2WeightShrinkPolicy<K, V>(
53- sc.embedding_config .l2_weight_threshold ,
54- sc.embedding_config .primary_emb_index ,
55- Storage<K, V>::GetOffset (sc.embedding_config .primary_emb_index ),
56- leveldb_, alloc_,
57- sc.embedding_config .slot_num + 1 );
58- } else {
59- dram_policy_ = nullptr ;
60- leveldb_policy_ = nullptr ;
61- }
62-
63- MultiTierStorage<K, V>::kvs_.emplace_back (
64- KVInterfaceDescriptor<K, V>(dram_kv_, alloc_, dram_mu_, dram_policy_));
65- MultiTierStorage<K, V>::kvs_.emplace_back (
66- KVInterfaceDescriptor<K, V>(leveldb_, alloc_,
67- leveldb_mu_, leveldb_policy_));
36+ : MultiTierStorage<K, V>(sc, name) {
37+ dram_ = new DramStorage<K, V>(sc, alloc, lc, new LocklessHashMap<K, V>());
38+ leveldb_ = new LevelDBStore<K, V>(sc, alloc, lc);
6839 }
6940
7041 ~DramLevelDBStore () override {
71- MultiTierStorage<K, V>::ReleaseValues (
72- {std::make_pair (dram_kv_, alloc_)});
73- if (dram_policy_ != nullptr ) delete dram_policy_;
74- if (leveldb_policy_ != nullptr ) delete leveldb_policy_;
42+ MultiTierStorage<K, V>::DeleteFromEvictionManager ();
43+ delete dram_;
44+ delete leveldb_;
7545 }
7646
7747 TF_DISALLOW_COPY_AND_ASSIGN (DramLevelDBStore);
7848
7949 Status Get (K key, ValuePtr<V>** value_ptr) override {
80- Status s = dram_kv_-> Lookup (key, value_ptr);
50+ Status s = dram_-> Get (key, value_ptr);
8151 if (!s.ok ()) {
82- s = leveldb_->Lookup (key, value_ptr);
52+ s = leveldb_->Get (key, value_ptr);
8353 }
8454 return s;
8555 }
@@ -89,48 +59,38 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
8959 }
9060
9161 void Insert (K key, ValuePtr<V>** value_ptr,
92- int64 alloc_len) override {
93- do {
94- *value_ptr = layout_creator_->Create (alloc_, alloc_len);
95- Status s = dram_kv_->Insert (key, *value_ptr);
96- if (s.ok ()) {
97- break ;
98- } else {
99- (*value_ptr)->Destroy (alloc_);
100- delete *value_ptr;
101- }
102- } while (!(dram_kv_->Lookup (key, value_ptr)).ok ());
62+ size_t alloc_len) override {
63+ dram_->Insert (key, value_ptr, alloc_len);
10364 }
10465
10566 Status GetOrCreate (K key, ValuePtr<V>** value_ptr,
10667 size_t size, CopyBackFlag &need_copyback) override {
107- need_copyback = NOT_COPYBACK;
108- return GetOrCreate (key, value_ptr, size);
68+ LOG (FATAL)<<" GetOrCreate(K key, ValuePtr<V>** value_ptr, "
69+ <<" size_t size, CopyBackFlag &need_copyback) "
70+ <<" in DramLevelDBStore can not be called." ;
10971 }
11072
11173 Status GetOrCreate (K key, ValuePtr<V>** value_ptr,
11274 size_t size) override {
113- Status s = dram_kv_-> Lookup (key, value_ptr);
75+ Status s = dram_-> Get (key, value_ptr);
11476 if (s.ok ()) {
11577 return s;
11678 }
117- s = leveldb_->Lookup (key, value_ptr);
118- if (!s.ok ()) {
119- *value_ptr = layout_creator_->Create (alloc_, size);
120- }
121-
122- s = dram_kv_->Insert (key, *value_ptr);
79+ s = leveldb_->Get (key, value_ptr);
12380 if (s.ok ()) {
124- return s;
81+ s = dram_->TryInsert (key, *value_ptr);
82+ if (s.ok ()) {
83+ return s;
84+ }
85+ leveldb_->DestroyValuePtr (*value_ptr);
86+ return dram_->Get (key, value_ptr);
12587 }
126- // Insert Failed, key already exist
127- (*value_ptr)->Destroy (alloc_);
128- delete *value_ptr;
129- return dram_kv_->Lookup (key, value_ptr);
88+ dram_->Insert (key, value_ptr, size);
89+ return Status::OK ();
13090 }
13191
13292 Status Remove (K key) override {
133- dram_kv_ ->Remove (key);
93+ dram_ ->Remove (key);
13494 leveldb_->Remove (key);
13595 return Status::OK ();
13696 }
@@ -150,46 +110,122 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
150110 }
151111
152112 void iterator_mutex_lock () override {
153- leveldb_mu_. lock ();
113+ leveldb_-> get_mutex ()-> lock ();
154114 }
155115
156116 void iterator_mutex_unlock () override {
157- leveldb_mu_. unlock ();
117+ leveldb_-> get_mutex ()-> unlock ();
158118 }
159119
160120 int64 Size () const override {
161- int64 total_size = dram_kv_ ->Size ();
121+ int64 total_size = dram_ ->Size ();
162122 total_size += leveldb_->Size ();
163123 return total_size;
164124 }
165125
126+ int64 Size (int level) const override {
127+ if (level == 0 ) {
128+ return dram_->Size ();
129+ } else if (level == 1 ) {
130+ return leveldb_->Size ();
131+ } else {
132+ return -1 ;
133+ }
134+ }
135+
136+ int LookupTier (K key) const override {
137+ Status s = dram_->Contains (key);
138+ if (s.ok ())
139+ return 0 ;
140+ s = leveldb_->Contains (key);
141+ if (s.ok ())
142+ return 1 ;
143+ return -1 ;
144+ }
145+
166146 Status GetSnapshot (std::vector<K>* key_list,
167147 std::vector<ValuePtr<V>*>* value_ptr_list) override {
168148 {
169- mutex_lock l (dram_mu_ );
170- TF_CHECK_OK (dram_kv_ ->GetSnapshot (key_list, value_ptr_list));
149+ mutex_lock l (*(dram_-> get_mutex ()) );
150+ TF_CHECK_OK (dram_ ->GetSnapshot (key_list, value_ptr_list));
171151 }
172152 {
173- mutex_lock l (leveldb_mu_ );
153+ mutex_lock l (*(leveldb_-> get_mutex ()) );
174154 TF_CHECK_OK (leveldb_->GetSnapshot (key_list, value_ptr_list));
175155 }
176156 return Status::OK ();
177157 }
178158
159+ Status Shrink (int64 value_len) override {
160+ dram_->Shrink (value_len);
161+ leveldb_->Shrink (value_len);
162+ return Status::OK ();
163+ }
164+
165+ Status Shrink (int64 global_step, int64 steps_to_live) override {
166+ dram_->Shrink (global_step, steps_to_live);
167+ leveldb_->Shrink (global_step, steps_to_live);
168+ return Status::OK ();
169+ }
170+
171+ int64 GetSnapshot (std::vector<K>* key_list,
172+ std::vector<V* >* value_list,
173+ std::vector<int64>* version_list,
174+ std::vector<int64>* freq_list,
175+ const EmbeddingConfig& emb_config,
176+ FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
177+ embedding::Iterator** it) override {
178+ {
179+ mutex_lock l (*(dram_->get_mutex ()));
180+ std::vector<ValuePtr<V>*> value_ptr_list;
181+ std::vector<K> key_list_tmp;
182+ TF_CHECK_OK (dram_->GetSnapshot (&key_list_tmp, &value_ptr_list));
183+ MultiTierStorage<K, V>::SetListsForCheckpoint (
184+ key_list_tmp, value_ptr_list, emb_config,
185+ key_list, value_list, version_list, freq_list);
186+ }
187+ {
188+ mutex_lock l (*(leveldb_->get_mutex ()));
189+ *it = leveldb_->GetIterator ();
190+ }
191+ return key_list->size ();
192+ }
193+
194+ Status Eviction (K* evict_ids, int64 evict_size) override {
195+ ValuePtr<V>* value_ptr;
196+ for (int64 i = 0 ; i < evict_size; ++i) {
197+ if (dram_->Get (evict_ids[i], &value_ptr).ok ()) {
198+ TF_CHECK_OK (leveldb_->Commit (evict_ids[i], value_ptr));
199+ TF_CHECK_OK (dram_->Remove (evict_ids[i]));
200+ dram_->DestroyValuePtr (value_ptr);
201+ }
202+ }
203+ return Status::OK ();
204+ }
205+
206+ Status EvictionWithDelayedDestroy (K* evict_ids, int64 evict_size) override {
207+ mutex_lock l (*(dram_->get_mutex ()));
208+ mutex_lock l1 (*(leveldb_->get_mutex ()));
209+ MultiTierStorage<K, V>::ReleaseInvalidValuePtr (dram_->alloc_ );
210+ ValuePtr<V>* value_ptr = nullptr ;
211+ for (int64 i = 0 ; i < evict_size; ++i) {
212+ if (dram_->Get (evict_ids[i], &value_ptr).ok ()) {
213+ TF_CHECK_OK (leveldb_->Commit (evict_ids[i], value_ptr));
214+ TF_CHECK_OK (dram_->Remove (evict_ids[i]));
215+ MultiTierStorage<K, V>::KeepInvalidValuePtr (value_ptr);
216+ }
217+ }
218+ return Status::OK ();
219+ }
220+
179221 protected:
180222 void SetTotalDims (int64 total_dims) override {
181223 leveldb_->SetTotalDims (total_dims);
182224 }
183225
184226 private:
185- KVInterface<K, V>* dram_kv_;
186- KVInterface<K, V>* leveldb_;
187- Allocator* alloc_;
188- ShrinkPolicy<K, V>* dram_policy_;
189- ShrinkPolicy<K, V>* leveldb_policy_;
190- LayoutCreator<V>* layout_creator_;
191- mutex dram_mu_; // must be locked before leveldb_mu_ is locked
192- mutex leveldb_mu_;
227+ DramStorage<K, V>* dram_;
228+ LevelDBStore<K, V>* leveldb_;
193229};
194230} // embedding
195231} // tensorflow
0 commit comments