Skip to content

Commit 2d31c8e

Browse files
authored
[Embedding] Add interface of EmbeddingVar for Elastic Training. (#933)
Signed-off-by: JunqiHu <[email protected]>
1 parent 0e8127a commit 2d31c8e

21 files changed

+244
-20
lines changed

configure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ def main():
14341434
True, 'star')
14351435

14361436
set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING', 'with_elastic_support',
1437-
True, 'elastic')
1437+
False, 'elastic')
14381438

14391439
set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support',
14401440
False, 'pmem')

tensorflow/contrib/elastic_grpc_server/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ cc_library(
5656
tf_cc_test(
5757
name = "elastic_grpc_test",
5858
size = "small",
59-
srcs = ["elastic_grpc_server_lib_test.cc"],
59+
srcs = select({"//tensorflow:with_elastic_support": ["elastic_grpc_server_lib_test.cc"],
60+
"//conditions:default": []}),
6061
deps = [
6162
":elastic_grpc_server_lib",
6263
"//tensorflow/core/distributed_runtime/rpc:grpc_util",

tensorflow/core/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ load(
128128
"tf_additional_numa_deps",
129129
"tf_additional_numa_lib_defines",
130130
"tf_additional_star_lib_defines",
131+
"tf_additional_elastic_server_lib_defines",
131132
"tf_additional_api_compatible_defines",
132133
"tf_additional_pmem_lib_defines",
133134
"tf_additional_test_deps",
@@ -1441,6 +1442,7 @@ tf_cc_test(
14411442
cc_library(
14421443
name = "ops",
14431444
visibility = ["//visibility:public"],
1445+
defines = tf_additional_elastic_server_lib_defines(),
14441446
deps = [
14451447
":array_ops_op_lib",
14461448
":parquet_ops_op_lib",
@@ -2562,7 +2564,8 @@ LIB_INTERNAL_DEFINES = (
25622564
tf_additional_gdr_lib_defines() +
25632565
tf_additional_numa_lib_defines() +
25642566
tf_additional_star_lib_defines() +
2565-
tf_additional_pmem_lib_defines()
2567+
tf_additional_pmem_lib_defines() +
2568+
tf_additional_elastic_server_lib_defines()
25662569
)
25672570

25682571
cc_library(

tensorflow/core/framework/embedding/bloom_filter_policy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
333333
// this can describe by graph(Mod + DynamicPartition),
334334
// but memory waste and slow
335335
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
336-
LOG(INFO) << "skip EV key:" << *(key_buff + i);
336+
VLOG(1) << "skip EV key:" << *(key_buff + i);
337337
continue;
338338
}
339339
void* value_ptr = nullptr;

tensorflow/core/framework/embedding/counter_filter_policy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
159159
// this can describe by graph(Mod + DynamicPartition),
160160
// but memory waste and slow
161161
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
162-
LOG(INFO) << "skip EV key:" << *(key_buff + i);
162+
VLOG(1) << "skip EV key:" << *(key_buff + i);
163163
continue;
164164
}
165165
int64 import_freq = 0;

tensorflow/core/framework/embedding/cpu_hash_map_kv.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,28 @@ class LocklessHashMap : public KVInterface<K, V> {
137137
return Status::OK();
138138
}
139139

140+
// Appends to `key_list` / `value_ptr_list` every occupied entry whose key
// does NOT map to `partition_id`, where the mapping is
// key % kSavedPartitionNum % partition_nums. Slots holding the EMPTY_KEY_ or
// DELETED_KEY_ sentinels are ignored. Always returns Status::OK().
Status GetShardedSnapshot(
    std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
    int partition_id, int partition_nums) override {
  auto snapshot = hash_map_.GetSnapshot();
  std::pair<const K, void*>* entries = snapshot.first;
  const int64 num_buckets = snapshot.second;

  for (int64 idx = 0; idx < num_buckets; ++idx) {
    const K& key = entries[idx].first;
    // Sentinel keys mark unused or erased slots.
    if (key == LocklessHashMap<K, V>::EMPTY_KEY_ ||
        key == LocklessHashMap<K, V>::DELETED_KEY_) {
      continue;
    }
    // Keys that map to this partition are not part of the result.
    if (key % kSavedPartitionNum % partition_nums == partition_id) {
      continue;
    }
    key_list->emplace_back(key);
    value_ptr_list->emplace_back(entries[idx].second);
  }

  // The array handed back by GetSnapshot() is released here with free().
  free(entries);
  return Status::OK();
}
161+
140162
std::string DebugString() const override {
141163
LOG(INFO) << "map info size:" << Size()
142164
<< "map info bucket_count:" << hash_map_.bucket_count()

tensorflow/core/framework/embedding/dense_hash_map_kv.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,25 @@ class DenseHashMap : public KVInterface<K, V> {
121121
return Status::OK();
122122
}
123123

124+
// Appends to `key_list` / `value_ptr_list` every entry whose key does NOT
// map to `partition_id`, where the mapping is
// key % kSavedPartitionNum % partition_nums. Always returns Status::OK().
//
// Each internal shard is first copied while holding that shard's read lock;
// the filtering then runs on the copies, outside the locks.
Status GetShardedSnapshot(
    std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
    int partition_id, int partition_nums) override {
  // std::vector replaces the former variable-length array, which is a
  // compiler extension rather than standard C++ and lived on the stack.
  std::vector<dense_hash_map> hash_map_dump(partition_num_);
  for (int i = 0; i < partition_num_; i++) {
    spin_rd_lock l(hash_map_[i].mu);
    hash_map_dump[i].hash_map = hash_map_[i].hash_map;
  }
  for (const auto& shard : hash_map_dump) {
    // Iterate by reference so each (key, value pointer) pair is not copied.
    for (const auto& it : shard.hash_map) {
      if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
        key_list->push_back(it.first);
        value_ptr_list->push_back(it.second);
      }
    }
  }
  return Status::OK();
}
142+
124143
std::string DebugString() const override {
125144
return "";
126145
}

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ class EmbeddingVar : public ResourceBase {
435435
return storage_->CacheSize();
436436
}
437437

438+
// Returns the number of stored entries multiplied by the per-entry byte
// count (sizeof(K) plus the feature descriptor's data_bytes()).
int64 MemoryUsage() const {
  const auto bytes_per_entry = sizeof(K) + feat_desc_->data_bytes();
  return storage_->Size() * bytes_per_entry;
}
441+
438442
int64 MinFreq() {
439443
return emb_config_.filter_freq;
440444
}
@@ -516,6 +520,85 @@ class EmbeddingVar : public ResourceBase {
516520
}
517521
}
518522

523+
// Forwards to the underlying storage's GetShardedSnapshot() with the same
// arguments and returns its Status unchanged.
Status GetShardedSnapshot(std::vector<K>* key_list,
                          std::vector<void*>* value_ptr_list,
                          int partition_id, int partition_num) {
  Status snapshot_status = storage_->GetShardedSnapshot(
      key_list, value_ptr_list, partition_id, partition_num);
  return snapshot_status;
}
529+
530+
// Exports the entries listed in `tot_keys_list` / `tot_value_ptr_list` into
// the caller-provided flat arrays and, when this variable is primary,
// removes each processed key from storage and deallocates its value memory.
//
// Entry i is written to slot i of every output array (`value_list` uses
// value_len_ elements per slot), so entries that are skipped leave their
// slots untouched. `version_list` / `freq_list` are only written when the
// embedding config enables saving versions / frequencies.
//
// Entries are skipped (neither exported nor removed) when their value
// pointer equals ValuePtrStatus::IS_DELETED, or when they take the
// non-export branch while TF_EV_SAVE_FILTERED_FEATURES is false.
//
// `tot_value_ptr_list` is modified in place: for entries flagged as not in
// DRAM, the bits at and above kDramFlagOffset are cleared.
void ExportAndRemove(K* key_list, V* value_list,
                     int64* version_list, int64* freq_list,
                     std::vector<K>& tot_keys_list,
                     std::vector<void*>& tot_value_ptr_list) {
  bool save_unfiltered_features = true;
  TF_CHECK_OK(ReadBoolFromEnvVar(
      "TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features));

  const bool is_save_freq = emb_config_.is_save_freq();
  const bool is_save_version = emb_config_.is_save_version();
  // Hoisted so the loop compares signed against signed.
  const int64 total_keys = static_cast<int64>(tot_keys_list.size());

  for (int64 i = 0; i < total_keys; ++i) {
    auto& value_ptr = tot_value_ptr_list[i];
    if (reinterpret_cast<int64>(value_ptr) ==
        embedding::ValuePtrStatus::IS_DELETED) {
      continue;
    }

    const bool is_admit = feat_desc_->IsAdmit(value_ptr);
    const bool is_in_dram =
        (reinterpret_cast<int64>(value_ptr) >> kDramFlagOffset == 0);

    // NOTE(review): entries are exported when IsAdmit() is false, while the
    // TODO in the else branch refers to filtered keys. Confirm the polarity
    // of this condition against the intended behaviour.
    if (!is_admit) {
      key_list[i] = tot_keys_list[i];
      V* value_slot = value_list + i * value_len_;

      if (!is_in_dram) {
        // No embedding is copied for this entry; the value slot keeps
        // whatever the caller put there. Only the flag bits are stripped so
        // the pointer can be used by the calls below.
        value_ptr = reinterpret_cast<void*>(
            reinterpret_cast<int64>(value_ptr) &
            ((1L << kDramFlagOffset) - 1));
      } else if (feat_desc_->GetEmbedding(value_ptr, 0) == nullptr) {
        memcpy(value_slot, default_value_, sizeof(V) * value_len_);
      } else {
        V* val = feat_desc_->GetEmbedding(value_ptr, emb_config_.emb_index);
        memcpy(value_slot, val, sizeof(V) * value_len_);
      }

      if (is_save_version) {
        version_list[i] = feat_desc_->GetVersion(value_ptr);
      }

      if (is_save_freq) {
        freq_list[i] = feat_desc_->GetFreq(value_ptr);
      }
    } else {
      // Skip only this entry. A `return` here would silently abandon every
      // remaining key in the list.
      if (!save_unfiltered_features) {
        continue;
      }
      //TODO(JUNQI) : currently not export filtered keys
    }

    if (emb_config_.is_primary()) {
      Status s = storage_->Remove(tot_keys_list[i]);
      if (!s.ok()) {
        LOG(ERROR) << "Remove keys error: " << s.error_message();
      }
      // The value memory is released even when Remove() reported an error.
      feat_desc_->Deallocate(value_ptr);
    }
  }
}
588+
589+
// Wraps the caller's key / value / version / frequency arrays in a
// RestoreBuffer built from raw pointers and passes it to the storage's
// RestoreFeatures(), with is_filter and is_incr both false. Returns the
// Status produced by RestoreFeatures().
Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
                                int partition_num, const K* key_list,
                                const V* value_list, const int64* version_list,
                                const int64* freq_list,
                                const Eigen::GpuDevice* device = nullptr) {
  // RestoreBuffer takes mutable char pointers, so constness is cast away.
  char* key_bytes = const_cast<char*>(
      reinterpret_cast<const char*>(key_list));
  char* value_bytes = const_cast<char*>(
      reinterpret_cast<const char*>(value_list));
  char* version_bytes = const_cast<char*>(
      reinterpret_cast<const char*>(version_list));
  char* freq_bytes = const_cast<char*>(
      reinterpret_cast<const char*>(freq_list));

  RestoreBuffer restore_buff(key_bytes, value_bytes,
                             version_bytes, freq_bytes);
  const bool is_filter = false;
  const bool is_incr = false;
  return storage_->RestoreFeatures(key_num, kSavedPartitionNum,
                                   partition_id, partition_num,
                                   value_len_, is_filter, is_incr,
                                   emb_config_, device, filter_,
                                   restore_buff);
}
601+
519602
mutex* mu() {
520603
return &mu_;
521604
}
@@ -537,6 +620,8 @@ class EmbeddingVar : public ResourceBase {
537620
}
538621
}
539622

623+
// Returns a copy of this variable's name_ member.
string Name() {
  return name_;
}
624+
540625
V* GetDefaultValuePtr() {
541626
return default_value_;
542627
}
@@ -645,7 +730,6 @@ class EmbeddingVar : public ResourceBase {
645730
GPUHashTable<K, V>* HashTable() {
646731
return storage_->HashTable();
647732
}
648-
649733
FilterPolicy<K, V, EmbeddingVar<K, V>>* GetFilter() const {
650734
return filter_;
651735
}

tensorflow/core/framework/embedding/embedding_var_ckpt_data.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ limitations under the License.
2020
namespace tensorflow {
2121
class BundleWriter;
2222
namespace {
23-
const int kSavedPartitionNum = 1000;
2423
const int kDramFlagOffset = 49;
2524
}
2625

tensorflow/core/framework/embedding/filter_policy.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,31 @@ struct RestoreBuffer {
2727
char* value_buffer = nullptr;
2828
char* version_buffer = nullptr;
2929
char* freq_buffer = nullptr;
30+
bool should_release = false;
3031

3132
explicit RestoreBuffer(size_t buffer_size) {
3233
key_buffer = new char[buffer_size];
3334
value_buffer = new char[buffer_size];
3435
version_buffer = new char[buffer_size];
3536
freq_buffer = new char[buffer_size];
37+
should_release = true;
38+
}
39+
40+
explicit RestoreBuffer(char* i_key_buffer, char* i_value_buffer,
41+
char* i_version_buffer, char* i_freq_buffer) {
42+
key_buffer = i_key_buffer;
43+
value_buffer = i_value_buffer;
44+
version_buffer = i_version_buffer;
45+
freq_buffer = i_freq_buffer;
3646
}
3747

3848
// Deletes the four buffers only when should_release is set; otherwise the
// pointers are left alone.
~RestoreBuffer() {
  if (!should_release) {
    return;
  }
  delete[] key_buffer;
  delete[] value_buffer;
  delete[] version_buffer;
  delete[] freq_buffer;
}
4456
};
4557

0 commit comments

Comments
 (0)