Skip to content

Commit f09e5ec

Browse files
authored
[Embedding] Add GetSnapshot and Create API for EmbeddingVariable. (#923)
Signed-off-by: lixy9474 <[email protected]>
1 parent 4983e02 commit f09e5ec

File tree

3 files changed

+88
-6
lines changed

3 files changed

+88
-6
lines changed

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,13 @@ class EmbeddingVar : public ResourceBase {
186186
}
187187
}
188188

189+
Status Insert(K key, V* value) {
190+
ValuePtr<V>* value_ptr = nullptr;
191+
CreateKey(key, &value_ptr, true);
192+
LookupOrCreateEmb(value_ptr, value);
193+
return Status::OK();
194+
}
195+
189196
Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr) {
190197
Status s = storage_->GetOrCreate(key, value_ptr,
191198
emb_config_.total_num(storage_->GetAllocLen()));
@@ -592,6 +599,34 @@ class EmbeddingVar : public ResourceBase {
592599
default_value_);
593600
}
594601

602+
void GetSnapshot(std::vector<K>* key_list,
603+
std::vector<V*>* value_list,
604+
std::vector<int64>* version_list,
605+
std::vector<int64>* freq_list) {
606+
std::vector<ValuePtr<V>*> value_ptr_list;
607+
storage_->GetSnapshot(key_list, &value_ptr_list);
608+
bool is_save_freq = emb_config_.is_save_freq();
609+
bool is_save_version = emb_config_.is_save_version();
610+
for (int64 i = 0; i < key_list->size(); i++) {
611+
V* val = value_ptr_list[i]->GetValue(emb_config_.emb_index, 0);
612+
if (val != nullptr) {
613+
value_list->emplace_back(val);
614+
} else {
615+
value_list->emplace_back(default_value_);
616+
}
617+
618+
if(is_save_version) {
619+
int64 dump_version = value_ptr_list[i]->GetStep();
620+
version_list->emplace_back(dump_version);
621+
}
622+
623+
if(is_save_freq) {
624+
int64 dump_freq = value_ptr_list[i]->GetFreq();
625+
freq_list->emplace_back(dump_freq);
626+
}
627+
}
628+
}
629+
595630
mutex* mu() {
596631
return &mu_;
597632
}

tensorflow/core/framework/embedding/eviction_manager.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ class EvictionManager {
4747
"EVICTION_MANAGER", 3, /*low_latency_hint=*/false));
4848
}
4949

50-
~EvictionManager() {
51-
}
50+
~EvictionManager() {}
5251

5352
TF_DISALLOW_COPY_AND_ASSIGN(EvictionManager);
5453

@@ -124,8 +123,8 @@ class EvictionManager {
124123
int64 num_of_threads_;
125124
int64 num_of_active_threads_;
126125
std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
127-
std::unique_ptr<thread::ThreadPool> thread_pool_;
128126
std::map<MultiTierStorage<K,V>*, StorageItem<K, V>*> storage_table_;
127+
std::unique_ptr<thread::ThreadPool> thread_pool_;
129128
mutex mu_;
130129
};
131130

tensorflow/core/kernels/embedding_variable_ops_test.cc

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,7 @@ TEST(EmbeddingVariableTest, TestLFUCache) {
11911191
}
11921192

11931193
TEST(EmbeddingVariableTest, TestCacheRestore) {
1194+
setenv("TF_SSDHASH_ASYNC_COMPACTION", "false", 1);
11941195
int64 value_size = 4;
11951196
Tensor value(DT_FLOAT, TensorShape({value_size}));
11961197
test::FillValues<float>(&value, std::vector<float>(value_size, 9.0));
@@ -1237,8 +1238,11 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {
12371238
LOG(INFO) << "size:" << variable->Size();
12381239

12391240
BundleWriter writer(Env::Default(), Prefix("foo"));
1240-
DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor);
1241-
TF_ASSERT_OK(writer.Finish());
1241+
embedding::ShrinkArgs shrink_args;
1242+
shrink_args.global_step = 1;
1243+
variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args);
1244+
TF_ASSERT_OK(writer.Finish());
1245+
variable->Unref();
12421246

12431247
auto imported_storage= embedding::StorageFactory::Create<int64, float>(
12441248
embedding::StorageConfig(embedding::DRAM_SSDHASH,
@@ -1258,6 +1262,7 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {
12581262

12591263
ASSERT_EQ(imported_storage->Size(0), ev_size - cache_size);
12601264
ASSERT_EQ(imported_storage->Size(1), 2);
1265+
delete imported_storage;
12611266
}
12621267

12631268
void t1_gpu(KVInterface<int64, float>* hashmap) {
@@ -1703,7 +1708,50 @@ TEST(EmbeddingVariableTest, TestLookupRemoveConcurrency) {
17031708
for (auto &t : insert_threads) {
17041709
t.join();
17051710
}
1706-
}
1711+
}
1712+
1713+
TEST(EmbeddingVariableTest, TestInsertAndGetSnapshot) {
1714+
int value_size = 10;
1715+
Tensor value(DT_FLOAT, TensorShape({value_size}));
1716+
test::FillValues<float>(&value, std::vector<float>(value_size, 10.0));
1717+
auto emb_config = EmbeddingConfig(
1718+
/*emb_index = */0, /*primary_emb_index = */0,
1719+
/*block_num = */1, /*slot_num = */0,
1720+
/*name = */"", /*steps_to_live = */0,
1721+
/*filter_freq = */0, /*max_freq = */999999,
1722+
/*l2_weight_threshold = */-1.0, /*layout = */"normal",
1723+
/*max_element_size = */0, /*false_positive_probability = */-1.0,
1724+
/*counter_type = */DT_UINT64);
1725+
auto storage = embedding::StorageFactory::Create<int64, float>(
1726+
embedding::StorageConfig(), cpu_allocator(), "EmbeddingVar");
1727+
auto var = new EmbeddingVar<int64, float>("EmbeddingVar",
1728+
storage,
1729+
emb_config,
1730+
cpu_allocator());
1731+
var->Init(value, 1);
1732+
float* set_value = (float*)malloc(value_size * sizeof(float));
1733+
//Insertion
1734+
for (int i = 0; i < 100; i++) {
1735+
for (int j = 0; j < value_size; j++) {
1736+
set_value[j] = i + j;
1737+
}
1738+
var->Insert(i, set_value);
1739+
}
1740+
free(set_value);
1741+
//GetSnapshot
1742+
std::vector<int64> key_list;
1743+
std::vector<float*> value_ptr_list;
1744+
std::vector<int64> version_list;
1745+
std::vector<int64> freq_list;
1746+
var->GetSnapshot(&key_list, &value_ptr_list,
1747+
&version_list, &freq_list);
1748+
for (int i = 0; i < key_list.size(); i++) {
1749+
ASSERT_EQ(key_list[i], i);
1750+
for (int j = 0; j < value_size; j++) {
1751+
ASSERT_EQ(value_ptr_list[i][j], i + j);
1752+
}
1753+
}
1754+
}
17071755

17081756
} // namespace
17091757
} // namespace embedding

0 commit comments

Comments
 (0)