@@ -35,9 +35,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
3535 using FilterPolicy<K, V, EV>::config_;
3636
3737 public:
38- BloomFilterPolicy (const EmbeddingConfig& config, EV* ev) :
39- FilterPolicy<K, V, EV>(config, ev) {
40-
38+ BloomFilterPolicy (const EmbeddingConfig& config, EV* ev,
39+ embedding::FeatureDescriptor<V>* feat_desc)
40+ : feat_desc_(feat_desc),
41+ FilterPolicy<K, V, EV>(config, ev) {
4142 switch (config_.counter_type ){
4243 case DT_UINT64:
4344 VLOG (2 ) << " The type of bloom counter is uint64" ;
@@ -64,10 +65,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
6465
6566 Status Lookup (K key, V* val, const V* default_value_ptr,
6667 const V* default_value_no_permission) override {
67- ValuePtr<V> * value_ptr = nullptr ;
68+ void * value_ptr = nullptr ;
6869 Status s = ev_->LookupKey (key, &value_ptr);
6970 if (s.ok ()) {
70- V* mem_val = ev_-> LookupOrCreateEmb (value_ptr, default_value_ptr );
71+ V* mem_val = feat_desc_-> GetEmbedding (value_ptr, config_. emb_index );
7172 memcpy (val, mem_val, sizeof (V) * ev_->ValueLen ());
7273 } else {
7374 memcpy (val, default_value_no_permission, sizeof (V) * ev_->ValueLen ());
@@ -81,17 +82,17 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
8182 int64 num_of_keys,
8283 V* default_value_ptr,
8384 V* default_value_no_permission) override {
84- std::vector<ValuePtr<V> *> value_ptr_list (num_of_keys, nullptr );
85+ std::vector<void *> value_ptr_list (num_of_keys, nullptr );
8586 ev_->BatchLookupKey (ctx, keys, value_ptr_list.data (), num_of_keys);
8687 std::vector<V*> embedding_ptr (num_of_keys, nullptr );
8788 auto do_work = [this , value_ptr_list, &embedding_ptr,
8889 default_value_ptr, default_value_no_permission]
8990 (int64 start, int64 limit) {
9091 for (int i = start; i < limit; i++) {
91- ValuePtr<V> * value_ptr = value_ptr_list[i];
92+ void * value_ptr = value_ptr_list[i];
9293 if (value_ptr != nullptr ) {
9394 embedding_ptr[i] =
94- ev_-> LookupOrCreateEmb (value_ptr, default_value_ptr );
95+ feat_desc_-> GetEmbedding (value_ptr, config_. emb_index );
9596 } else {
9697 embedding_ptr[i] = default_value_no_permission;
9798 }
@@ -109,13 +110,13 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
109110 }
110111
111112 void BatchLookupOrCreateKey (const EmbeddingVarContext<GPUDevice>& ctx,
112- const K* keys, ValuePtr<V> ** value_ptrs_list,
113+ const K* keys, void ** value_ptrs_list,
113114 int64 num_of_keys) {
114115 int num_worker_threads = ctx.worker_threads ->num_threads ;
115116 std::vector<std::vector<K>> lookup_or_create_ids (num_worker_threads);
116117 std::vector<std::vector<int >>
117118 lookup_or_create_cursor (num_worker_threads);
118- std::vector<std::vector<ValuePtr<V> *>>
119+ std::vector<std::vector<void *>>
119120 lookup_or_create_ptrs (num_worker_threads);
120121 IntraThreadCopyIdAllocator thread_copy_id_alloc (num_worker_threads);
121122 std::vector<std::list<int64>>
@@ -147,7 +148,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
147148 1000 , do_work);
148149
149150 std::vector<K> total_ids (num_of_keys);
150- std::vector<ValuePtr<V> *> total_ptrs (num_of_keys);
151+ std::vector<void *> total_ptrs (num_of_keys);
151152 std::vector<int > total_cursors (num_of_keys);
152153 int num_of_admit_id = 0 ;
153154 for (int i = 0 ; i < num_worker_threads; i++) {
@@ -157,7 +158,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
157158 sizeof (K) * lookup_or_create_ids[i].size ());
158159 memcpy (total_ptrs.data () + num_of_admit_id,
159160 lookup_or_create_ptrs[i].data (),
160- sizeof (ValuePtr<V> *) * lookup_or_create_ptrs[i].size ());
161+ sizeof (void *) * lookup_or_create_ptrs[i].size ());
161162 memcpy (total_cursors.data () + num_of_admit_id,
162163 lookup_or_create_cursor[i].data (),
163164 sizeof (int ) * lookup_or_create_cursor[i].size ());
@@ -174,31 +175,40 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
174175#endif // GOOGLE_CUDA
175176
176177 void LookupOrCreate (K key, V* val, const V* default_value_ptr,
177- ValuePtr<V> ** value_ptr, int count,
178+ void ** value_ptr, int count,
178179 const V* default_value_no_permission) override {
179180 if (GetBloomFreq (key) >= config_.filter_freq ) {
180- TF_CHECK_OK (ev_->LookupOrCreateKey (key, value_ptr));
181- V* mem_val = ev_->LookupOrCreateEmb (*value_ptr, default_value_ptr);
181+ bool is_filter = true ;
182+ TF_CHECK_OK (LookupOrCreateKey (key, value_ptr, &is_filter, count));
183+ V* mem_val = feat_desc_->GetEmbedding (*value_ptr, config_.emb_index );
182184 memcpy (val, mem_val, sizeof (V) * ev_->ValueLen ());
183185 } else {
184186 AddFreq (key, count);
185187 memcpy (val, default_value_no_permission, sizeof (V) * ev_->ValueLen ());
186188 }
187189 }
188190
189- Status LookupOrCreateKey (K key, ValuePtr<V> ** val ,
191+ Status LookupOrCreateKey (K key, void ** value_ptr ,
190192 bool * is_filter, int64 count) override {
191- *val = nullptr ;
192- if ((GetFreq (key, *val) + count) >= config_.filter_freq ) {
193+ *value_ptr = nullptr ;
194+ if ((GetFreq (key, *value_ptr) + count) >= config_.filter_freq ) {
195+ Status s = ev_->LookupKey (key, value_ptr);
196+ if (!s.ok ()) {
197+ *value_ptr = feat_desc_->Allocate ();
198+ feat_desc_->SetDefaultValue (*value_ptr, key);
199+ ev_->storage ()->Insert (key, value_ptr);
200+ s = Status::OK ();
201+ }
193202 *is_filter = true ;
194- return ev_->LookupOrCreateKey (key, val);
203+ feat_desc_->AddFreq (*value_ptr, count);
204+ } else {
205+ *is_filter = false ;
206+ AddFreq (key, count);
195207 }
196- *is_filter = false ;
197- AddFreq (key, count);
198208 return Status::OK ();
199209 }
200210
201- int64 GetFreq (K key, ValuePtr<V>* ) override {
211+ int64 GetFreq (K key, void * val ) override {
202212 return GetBloomFreq (key);
203213 }
204214
@@ -210,7 +220,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
210220 return bloom_counter_;
211221 }
212222
213- bool is_admit (K key, ValuePtr<V> * value_ptr) override {
223+ bool is_admit (K key, void * value_ptr) override {
214224 if (value_ptr == nullptr ) {
215225 return false ;
216226 } else {
@@ -326,8 +336,12 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
326336 LOG (INFO) << " skip EV key:" << *(key_buff + i);
327337 continue ;
328338 }
329- ValuePtr<V> * value_ptr = nullptr ;
339+ void * value_ptr = nullptr ;
330340 int64 new_freq = freq_buff[i];
341+ int64 import_version = -1 ;
342+ if (config_.steps_to_live != 0 || config_.record_version ) {
343+ import_version = version_buff[i];
344+ }
331345 if (!is_filter) {
332346 if (freq_buff[i] >= config_.filter_freq ) {
333347 SetBloomFreq (key_buff[i], freq_buff[i]);
@@ -339,17 +353,9 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
339353 SetBloomFreq (key_buff[i], freq_buff[i]);
340354 }
341355 if (new_freq >= config_.filter_freq ){
342- ev_->CreateKey (key_buff[i], &value_ptr, to_dram);
343- if (config_.steps_to_live != 0 || config_.record_version ) {
344- value_ptr->SetStep (version_buff[i]);
345- }
346- if (!is_filter){
347- ev_->LookupOrCreateEmb (value_ptr,
348- value_buff + i * ev_->ValueLen ());
349- } else {
350- ev_->LookupOrCreateEmb (value_ptr,
351- ev_->GetDefaultValue (key_buff[i]));
352- }
356+ ev_->storage ()->Import (key_buff[i],
357+ value_buff + i * ev_->ValueLen (),
358+ new_freq, import_version, config_.emb_index );
353359 }
354360 }
355361 return Status::OK ();
@@ -449,6 +455,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
449455 }
450456 private:
451457 void * bloom_counter_;
458+ embedding::FeatureDescriptor<V>* feat_desc_;
452459 std::vector<int64> seeds_;
453460};
454461} // tensorflow
0 commit comments