@@ -1191,6 +1191,7 @@ TEST(EmbeddingVariableTest, TestLFUCache) {
11911191}
11921192
11931193TEST (EmbeddingVariableTest, TestCacheRestore) {
1194+ setenv (" TF_SSDHASH_ASYNC_COMPACTION" , " false" , 1 );
11941195 int64 value_size = 4 ;
11951196 Tensor value (DT_FLOAT, TensorShape ({value_size}));
11961197 test::FillValues<float >(&value, std::vector<float >(value_size, 9.0 ));
@@ -1237,8 +1238,11 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {
12371238 LOG (INFO) << " size:" << variable->Size ();
12381239
12391240 BundleWriter writer (Env::Default (), Prefix (" foo" ));
1240- DumpEmbeddingValues (variable, " var/part_0" , &writer, &part_offset_tensor);
1241- TF_ASSERT_OK (writer.Finish ());
1241+ embedding::ShrinkArgs shrink_args;
1242+ shrink_args.global_step = 1 ;
1243+ variable->Save (" var/part_0" , Prefix (" foo" ), &writer, shrink_args);
1244+ TF_ASSERT_OK (writer.Finish ());
1245+ variable->Unref ();
12421246
12431247 auto imported_storage= embedding::StorageFactory::Create<int64, float >(
12441248 embedding::StorageConfig (embedding::DRAM_SSDHASH,
@@ -1258,6 +1262,7 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {
12581262
12591263 ASSERT_EQ (imported_storage->Size (0 ), ev_size - cache_size);
12601264 ASSERT_EQ (imported_storage->Size (1 ), 2 );
1265+ delete imported_storage;
12611266}
12621267
12631268void t1_gpu (KVInterface<int64, float >* hashmap) {
@@ -1703,7 +1708,50 @@ TEST(EmbeddingVariableTest, TestLookupRemoveConcurrency) {
17031708 for (auto &t : insert_threads) {
17041709 t.join ();
17051710 }
1706- }
1711+ }
1712+
1713+ TEST (EmbeddingVariableTest, TestInsertAndGetSnapshot) {
1714+ int value_size = 10 ;
1715+ Tensor value (DT_FLOAT, TensorShape ({value_size}));
1716+ test::FillValues<float >(&value, std::vector<float >(value_size, 10.0 ));
1717+ auto emb_config = EmbeddingConfig (
1718+ /* emb_index = */ 0 , /* primary_emb_index = */ 0 ,
1719+ /* block_num = */ 1 , /* slot_num = */ 0 ,
1720+ /* name = */ " " , /* steps_to_live = */ 0 ,
1721+ /* filter_freq = */ 0 , /* max_freq = */ 999999 ,
1722+ /* l2_weight_threshold = */ -1.0 , /* layout = */ " normal" ,
1723+ /* max_element_size = */ 0 , /* false_positive_probability = */ -1.0 ,
1724+ /* counter_type = */ DT_UINT64);
1725+ auto storage = embedding::StorageFactory::Create<int64, float >(
1726+ embedding::StorageConfig (), cpu_allocator (), " EmbeddingVar" );
1727+ auto var = new EmbeddingVar<int64, float >(" EmbeddingVar" ,
1728+ storage,
1729+ emb_config,
1730+ cpu_allocator ());
1731+ var->Init (value, 1 );
1732+ float * set_value = (float *)malloc (value_size * sizeof (float ));
1733+ // Insertion
1734+ for (int i = 0 ; i < 100 ; i++) {
1735+ for (int j = 0 ; j < value_size; j++) {
1736+ set_value[j] = i + j;
1737+ }
1738+ var->Insert (i, set_value);
1739+ }
1740+ free (set_value);
1741+ // GetSnapshot
1742+ std::vector<int64> key_list;
1743+ std::vector<float *> value_ptr_list;
1744+ std::vector<int64> version_list;
1745+ std::vector<int64> freq_list;
1746+ var->GetSnapshot (&key_list, &value_ptr_list,
1747+ &version_list, &freq_list);
1748+ for (int i = 0 ; i < key_list.size (); i++) {
1749+ ASSERT_EQ (key_list[i], i);
1750+ for (int j = 0 ; j < value_size; j++) {
1751+ ASSERT_EQ (value_ptr_list[i][j], i + j);
1752+ }
1753+ }
1754+ }
17071755
17081756} // namespace
17091757} // namespace embedding
0 commit comments