Skip to content

Commit 8afcfff

Browse files
authored
[Embedding] Refactor embedding storage code by cleaning up useless StorageManager class. (#858)
Signed-off-by: lixy9474 <[email protected]>
1 parent d350de3 commit 8afcfff

17 files changed

+316
-625
lines changed

tensorflow/core/framework/embedding/bloom_filter_policy.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ limitations under the License.
2020
#include "tensorflow/core/framework/embedding/filter_policy.h"
2121

2222
namespace tensorflow {
23-
namespace embedding{
24-
template <class K, class V>
25-
class StorageManager;
26-
}
2723

2824
namespace {
2925
const static std::vector<int64> default_seeds = {
@@ -35,9 +31,8 @@ const static std::vector<int64> default_seeds = {
3531
template<typename K, typename V, typename EV>
3632
class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
3733
public:
38-
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev,
39-
embedding::StorageManager<K, V>* storage_manager) :
40-
config_(config), ev_(ev), storage_manager_(storage_manager) {
34+
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev)
35+
: config_(config), ev_(ev) {
4136
switch (config_.counter_type){
4237
case DT_UINT64:
4338
VLOG(2) << "The type of bloom counter is uint64";
@@ -349,7 +344,6 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
349344
EmbeddingConfig config_;
350345
EV* ev_;
351346
std::vector<int64> seeds_;
352-
embedding::StorageManager<K, V>* storage_manager_;
353347
};
354348
} // tensorflow
355349

tensorflow/core/framework/embedding/cache_factory.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ class CacheFactory {
2525
static BatchCache<K>* Create(CacheStrategy cache_strategy, std::string name) {
2626
switch (cache_strategy) {
2727
case CacheStrategy::LRU:
28-
LOG(INFO) << " Use StorageManager::LRU in multi-tier EmbeddingVariable "
28+
LOG(INFO) << " Use Storage::LRU in multi-tier EmbeddingVariable "
2929
<< name;
3030
return new LRUCache<K>();
3131
case CacheStrategy::LFU:
32-
LOG(INFO) << " Use StorageManager::LFU in multi-tier EmbeddingVariable "
32+
LOG(INFO) << " Use Storage::LFU in multi-tier EmbeddingVariable "
3333
<< name;
3434
return new LFUCache<K>();
3535
default:

tensorflow/core/framework/embedding/counter_filter_policy.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,12 @@ limitations under the License.
2020
#include "tensorflow/core/framework/embedding/filter_policy.h"
2121

2222
namespace tensorflow {
23-
namespace embedding{
24-
template <class K, class V>
25-
class StorageManager;
26-
}
2723

2824
template<typename K, typename V, typename EV>
2925
class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
3026
public:
31-
CounterFilterPolicy(const EmbeddingConfig& config,
32-
EV* ev, embedding::StorageManager<K, V>* storage_manager)
33-
: config_(config), ev_(ev), storage_manager_(storage_manager) {
27+
CounterFilterPolicy(const EmbeddingConfig& config, EV* ev)
28+
: config_(config), ev_(ev){
3429
}
3530

3631
Status Lookup(EV* ev, K key, V* val, const V* default_value_ptr,
@@ -165,7 +160,6 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
165160

166161
private:
167162
EmbeddingConfig config_;
168-
embedding::StorageManager<K, V>* storage_manager_;
169163
EV* ev_;
170164
};
171165

tensorflow/core/framework/embedding/embedding_var.cu.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void EmbeddingVar<K, V>::SetDefaultValueOfNewFeatures(
6666
reinterpret_cast<ValuePtr<V>*>(memcpy_address[*it]);
6767
value_address[i] =
6868
*((V**)((char*)(value_ptr->GetPtr()) + sizeof(FixedLengthHeader))) +
69-
storage_manager_->GetOffset(emb_config_.emb_index);
69+
storage_->GetOffset(emb_config_.emb_index);
7070
default_value_address[i] = get_default_v_fn(
7171
default_values, keys[*it], *it, GetDefaultValueDim(), ValueLen());
7272
}
@@ -86,7 +86,7 @@ void EmbeddingVar<K, V>::SetDefaultValueOfNewFeatures(
8686
value_ptr->SetInitialized(emb_config_.emb_index);
8787
memcpy_address[*it] = value_ptr->GetValue(
8888
emb_config_.emb_index,
89-
storage_manager_->GetOffset(emb_config_.emb_index));
89+
storage_->GetOffset(emb_config_.emb_index));
9090
}
9191
TypedAllocator::Deallocate(alloc_, dev_value_address, total * 2);
9292
TypedAllocator::Deallocate(cpu_allocator(), value_address, total * 2);
@@ -148,12 +148,12 @@ void EmbeddingVar<K, V>::CopyEmbeddingsFromCPUToGPU(
148148
int64* output_value_ptrs) {
149149
if (copyback_cursor.size() > 0) {
150150
int64 total = copyback_cursor.size();
151-
size_t value_len = emb_config_.total_num(storage_manager_->GetAllocLen());
151+
size_t value_len = emb_config_.total_num(storage_->GetAllocLen());
152152
V* memcpy_buffer_gpu = nullptr;
153153
ValuePtr<V>** gpu_value_ptrs = new ValuePtr<V>*[total];
154154
memcpy_buffer_gpu = (V*)alloc_->AllocateRaw(Allocator::kAllocatorAlignment,
155155
total * value_len * sizeof(V));
156-
storage_manager_->CopyEmbeddingsFromCPUToGPU(
156+
storage_->CopyEmbeddingsFromCPUToGPU(
157157
total, keys, copyback_cursor, memcpy_address, value_len, gpu_value_ptrs,
158158
memcpy_buffer_gpu, compute_stream, event_mgr, worker_threads);
159159

@@ -185,7 +185,7 @@ void EmbeddingVar<K, V>::CopyEmbeddingsFromCPUToGPU(
185185
auto do_insert = [this, copyback_keys, gpu_value_ptrs, value_len](
186186
int64 start, int64 limit) {
187187
for (int64 i = start; i < limit; i++)
188-
storage_manager_->Insert(copyback_keys[i], gpu_value_ptrs[i]);
188+
storage_->Insert(copyback_keys[i], gpu_value_ptrs[i]);
189189
};
190190
Shard(worker_threads->num_threads, worker_threads->workers,
191191
copyback_keys.size(), 100000, do_insert);

0 commit comments

Comments
 (0)