1+ /* Copyright 2022 The DeepRec Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ http://www.apache.org/licenses/LICENSE-2.0
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ======================================================================*/
15+ #include " tensorflow/core/framework/embedding/embedding_var_ckpt_data.h"
16+ #include " tensorflow/core/framework/embedding/embedding_var_dump_iterator.h"
17+ #include " tensorflow/core/kernels/save_restore_tensor.h"
18+ #include " tensorflow/core/framework/register_types.h"
19+
20+ namespace tensorflow {
21+ namespace embedding {
22+ template <class K , class V >
23+ void EmbeddingVarCkptData<K, V>::Emplace(
24+ K key, ValuePtr<V>* value_ptr,
25+ const EmbeddingConfig& emb_config,
26+ V* default_value, int64 value_offset,
27+ bool is_save_freq,
28+ bool is_save_version,
29+ bool save_unfiltered_features) {
30+ if ((int64)value_ptr == ValuePtrStatus::IS_DELETED)
31+ return ;
32+
33+ V* primary_val = value_ptr->GetValue (0 , 0 );
34+ bool is_not_admit =
35+ primary_val == nullptr
36+ && emb_config.filter_freq != 0 ;
37+
38+ if (!is_not_admit) {
39+ key_vec_.emplace_back (key);
40+
41+ if (primary_val == nullptr ) {
42+ value_ptr_vec_.emplace_back (default_value);
43+ } else if (
44+ (int64)primary_val == ValuePosition::NOT_IN_DRAM) {
45+ value_ptr_vec_.emplace_back ((V*)ValuePosition::NOT_IN_DRAM);
46+ } else {
47+ V* val = value_ptr->GetValue (emb_config.emb_index ,
48+ value_offset);
49+ value_ptr_vec_.emplace_back (val);
50+ }
51+
52+
53+ if (is_save_version) {
54+ int64 dump_version = value_ptr->GetStep ();
55+ version_vec_.emplace_back (dump_version);
56+ }
57+
58+ if (is_save_freq) {
59+ int64 dump_freq = value_ptr->GetFreq ();
60+ freq_vec_.emplace_back (dump_freq);
61+ }
62+ } else {
63+ if (!save_unfiltered_features)
64+ return ;
65+
66+ key_filter_vec_.emplace_back (key);
67+
68+ if (is_save_version) {
69+ int64 dump_version = value_ptr->GetStep ();
70+ version_filter_vec_.emplace_back (dump_version);
71+ }
72+
73+ int64 dump_freq = value_ptr->GetFreq ();
74+ freq_filter_vec_.emplace_back (dump_freq);
75+ }
76+ }
77+ #define REGISTER_KERNELS (ktype, vtype ) \
78+ template void EmbeddingVarCkptData<ktype, vtype>::Emplace( \
79+ ktype, ValuePtr<vtype>*, const EmbeddingConfig&, \
80+ vtype*, int64, bool , bool , bool );
81+ #define REGISTER_KERNELS_ALL_INDEX (type ) \
82+ REGISTER_KERNELS (int32, type) \
83+ REGISTER_KERNELS(int64, type)
84+ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
85+ #undef REGISTER_KERNELS_ALL_INDEX
86+ #undef REGISTER_KERNELS
87+
88+
89+ template <class K , class V >
90+ void EmbeddingVarCkptData<K, V>::Emplace(K key, V* value_ptr) {
91+ key_vec_.emplace_back (key);
92+ value_ptr_vec_.emplace_back (value_ptr);
93+ }
94+ #define REGISTER_KERNELS (ktype, vtype ) \
95+ template void EmbeddingVarCkptData<ktype, vtype>::Emplace( \
96+ ktype, vtype*);
97+ #define REGISTER_KERNELS_ALL_INDEX (type ) \
98+ REGISTER_KERNELS (int32, type) \
99+ REGISTER_KERNELS(int64, type)
100+ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
101+ #undef REGISTER_KERNELS_ALL_INDEX
102+ #undef REGISTER_KERNELS
103+
104+ template <class K , class V >
105+ void EmbeddingVarCkptData<K, V>::SetWithPartition(
106+ std::vector<EmbeddingVarCkptData<K, V>>& ev_ckpt_data_parts) {
107+ part_offset_.resize (kSavedPartitionNum + 1 );
108+ part_filter_offset_.resize (kSavedPartitionNum + 1 );
109+ part_offset_[0 ] = 0 ;
110+ part_filter_offset_[0 ] = 0 ;
111+ for (int i = 0 ; i < kSavedPartitionNum ; i++) {
112+ part_offset_[i + 1 ] =
113+ part_offset_[i] + ev_ckpt_data_parts[i].key_vec_ .size ();
114+
115+ part_filter_offset_[i + 1 ] =
116+ part_filter_offset_[i] +
117+ ev_ckpt_data_parts[i].key_filter_vec_ .size ();
118+
119+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].key_vec_ .size (); j++) {
120+ key_vec_.emplace_back (ev_ckpt_data_parts[i].key_vec_ [j]);
121+ }
122+
123+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].value_ptr_vec_ .size (); j++) {
124+ value_ptr_vec_.emplace_back (ev_ckpt_data_parts[i].value_ptr_vec_ [j]);
125+ }
126+
127+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].version_vec_ .size (); j++) {
128+ version_vec_.emplace_back (ev_ckpt_data_parts[i].version_vec_ [j]);
129+ }
130+
131+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].freq_vec_ .size (); j++) {
132+ freq_vec_.emplace_back (ev_ckpt_data_parts[i].freq_vec_ [j]);
133+ }
134+
135+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].key_filter_vec_ .size (); j++) {
136+ key_filter_vec_.emplace_back (ev_ckpt_data_parts[i].key_filter_vec_ [j]);
137+ }
138+
139+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].version_filter_vec_ .size (); j++) {
140+ version_filter_vec_.emplace_back (ev_ckpt_data_parts[i].version_filter_vec_ [j]);
141+ }
142+
143+ for (int64 j = 0 ; j < ev_ckpt_data_parts[i].freq_filter_vec_ .size (); j++) {
144+ freq_filter_vec_.emplace_back (ev_ckpt_data_parts[i].freq_filter_vec_ [j]);
145+ }
146+ }
147+ }
148+
149+ #define REGISTER_KERNELS (ktype, vtype ) \
150+ template void EmbeddingVarCkptData<ktype, vtype>::SetWithPartition( \
151+ std::vector<EmbeddingVarCkptData<ktype, vtype>>&);
152+ #define REGISTER_KERNELS_ALL_INDEX (type ) \
153+ REGISTER_KERNELS (int32, type) \
154+ REGISTER_KERNELS(int64, type)
155+ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
156+ #undef REGISTER_KERNELS_ALL_INDEX
157+ #undef REGISTER_KERNELS
158+
159+ template <class K , class V >
160+ Status EmbeddingVarCkptData<K, V>::ExportToCkpt(
161+ const string& tensor_name,
162+ BundleWriter* writer,
163+ int64 value_len,
164+ ValueIterator<V>* value_iter) {
165+ size_t bytes_limit = 8 << 20 ;
166+ std::unique_ptr<char []> dump_buffer (new char [bytes_limit]);
167+
168+ EVVectorDataDumpIterator<K> key_dump_iter (key_vec_);
169+ Status s = SaveTensorWithFixedBuffer (
170+ tensor_name + " -keys" , writer, dump_buffer.get (),
171+ bytes_limit, &key_dump_iter,
172+ TensorShape ({key_vec_.size ()}));
173+ if (!s.ok ())
174+ return s;
175+
176+ EV2dVectorDataDumpIterator<V> value_dump_iter (
177+ value_ptr_vec_, value_len, value_iter);
178+ s = SaveTensorWithFixedBuffer (
179+ tensor_name + " -values" , writer, dump_buffer.get (),
180+ bytes_limit, &value_dump_iter,
181+ TensorShape ({value_ptr_vec_.size (), value_len}));
182+ if (!s.ok ())
183+ return s;
184+
185+ EVVectorDataDumpIterator<int64> version_dump_iter (version_vec_);
186+ s = SaveTensorWithFixedBuffer (
187+ tensor_name + " -versions" , writer, dump_buffer.get (),
188+ bytes_limit, &version_dump_iter,
189+ TensorShape ({version_vec_.size ()}));
190+ if (!s.ok ())
191+ return s;
192+
193+ EVVectorDataDumpIterator<int64> freq_dump_iter (freq_vec_);
194+ s = SaveTensorWithFixedBuffer (
195+ tensor_name + " -freqs" , writer, dump_buffer.get (),
196+ bytes_limit, &freq_dump_iter,
197+ TensorShape ({freq_vec_.size ()}));
198+ if (!s.ok ())
199+ return s;
200+
201+ EVVectorDataDumpIterator<K> filtered_key_dump_iter (key_filter_vec_);
202+ s = SaveTensorWithFixedBuffer (
203+ tensor_name + " -keys_filtered" , writer, dump_buffer.get (),
204+ bytes_limit, &filtered_key_dump_iter,
205+ TensorShape ({key_filter_vec_.size ()}));
206+ if (!s.ok ())
207+ return s;
208+
209+ EVVectorDataDumpIterator<int64>
210+ filtered_version_dump_iter (version_filter_vec_);
211+ s = SaveTensorWithFixedBuffer (
212+ tensor_name + " -versions_filtered" ,
213+ writer, dump_buffer.get (),
214+ bytes_limit, &filtered_version_dump_iter,
215+ TensorShape ({version_filter_vec_.size ()}));
216+ if (!s.ok ())
217+ return s;
218+
219+ EVVectorDataDumpIterator<int64>
220+ filtered_freq_dump_iter (freq_filter_vec_);
221+ s = SaveTensorWithFixedBuffer (
222+ tensor_name + " -freqs_filtered" ,
223+ writer, dump_buffer.get (),
224+ bytes_limit, &filtered_freq_dump_iter,
225+ TensorShape ({freq_filter_vec_.size ()}));
226+ if (!s.ok ())
227+ return s;
228+
229+ EVVectorDataDumpIterator<int32>
230+ part_offset_dump_iter (part_offset_);
231+ s = SaveTensorWithFixedBuffer (
232+ tensor_name + " -partition_offset" ,
233+ writer, dump_buffer.get (),
234+ bytes_limit, &part_offset_dump_iter,
235+ TensorShape ({part_offset_.size ()}));
236+ if (!s.ok ())
237+ return s;
238+
239+ EVVectorDataDumpIterator<int32>
240+ part_filter_offset_dump_iter (part_filter_offset_);
241+ s = SaveTensorWithFixedBuffer (
242+ tensor_name + " -partition_filter_offset" ,
243+ writer, dump_buffer.get (),
244+ bytes_limit, &part_filter_offset_dump_iter,
245+ TensorShape ({part_filter_offset_.size ()}));
246+ if (!s.ok ())
247+ return s;
248+
249+ return Status::OK ();
250+ }
251+
252+ #define REGISTER_KERNELS (ktype, vtype ) \
253+ template Status EmbeddingVarCkptData<ktype, vtype>::ExportToCkpt( \
254+ const string&, BundleWriter*, int64, ValueIterator<vtype>*);
255+ #define REGISTER_KERNELS_ALL_INDEX (type ) \
256+ REGISTER_KERNELS (int32, type) \
257+ REGISTER_KERNELS(int64, type)
258+ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
259+ #undef REGISTER_KERNELS_ALL_INDEX
260+ #undef REGISTER_KERNELS
261+ }// namespace embedding
262+ }// namespace tensorflow
0 commit comments