@@ -25,16 +25,8 @@ namespace tensorflow {
2525namespace checkpoint {
2626
2727void WriteRestVariables (BundleReader& reader, BundleWriter& writer,
28- const std::vector<string>& names,
29- const std::set<string>& ev_suffix) {
30- std::set<string> updated_names;
31- for (int idx = 0 ; idx < names.size (); ++idx) {
32- updated_names.insert (names[idx] + " -values" );
33- for (auto it = ev_suffix.cbegin (); it != ev_suffix.cend (); ++it) {
34- updated_names.insert (names[idx] + *it);
35- }
36- }
37-
28+ const std::vector<string>& ignored_names) {
29+ std::set<string> excluded_names (ignored_names.cbegin (), ignored_names.cend ());
3830 std::vector<std::string> tensor_names;
3931 reader.Seek (kHeaderEntryKey );
4032 reader.Next ();
@@ -45,7 +37,7 @@ void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
4537 Status status;
4638 DataType dtype;
4739 TensorShape shape;
48- if (updated_names .count (tensor_name)) continue ;
40+ if (excluded_names .count (tensor_name)) continue ;
4941 status = reader.LookupDtypeAndShape (tensor_name, &dtype, &shape);
5042 if (status.ok ()) {
5143 Tensor tensor (dtype, shape);
@@ -55,6 +47,18 @@ void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
5547 }
5648}
5749
50+ void WriteRestVariables (BundleReader& reader, BundleWriter& writer,
51+ const std::vector<string>& ignored_names,
52+ const std::set<string>& ev_suffix) {
53+ std::vector<string> ev_names;
54+ for (int idx = 0 ; idx < ignored_names.size (); ++idx) {
55+ for (auto it = ev_suffix.cbegin (); it != ev_suffix.cend (); ++it) {
56+ ev_names.push_back (ignored_names[idx] + *it);
57+ }
58+ }
59+ WriteRestVariables (reader, writer, ev_names);
60+ }
61+
5862void ConvertToBF16Value (const Tensor& in_tensor, const string name,
5963 BundleWriter& writer) {
6064 auto in_data = in_tensor.flat <float >();
@@ -120,19 +124,21 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
120124 const std::vector<string>& names,
121125 const std::vector<string>& quant_names,
122126 const std::vector<string>& scale_names,
123- TF_DataType data_type) {
127+ const TF_DataType data_type,
128+ const bool is_ev) {
124129 BundleReader reader (Env::Default (), input_prefix);
125130 BundleWriter writer (Env::Default (), output_prefix);
126131 const std::set<string> ev_suffix = {
127132 " -freqs" , " -freqs_filtered" , " -keys" ,
128133 " -keys_filtered" , " -partition_filter_offset" , " -partition_offset" ,
129- " -versions" , " -versions_filtered" };
134+ " -versions" , " -versions_filtered" , " -values " };
130135
131136 for (int idx = 0 ; idx < names.size (); ++idx) {
132137 Status status;
133138 DataType dtype;
134139 TensorShape shape;
135- string value_name = names[idx] + " -values" ;
140+ string suffix = is_ev ? " -values" : " " ;
141+ string value_name = names[idx] + suffix;
136142 status = reader.LookupDtypeAndShape (value_name, &dtype, &shape);
137143 if (!status.ok ()) {
138144 errors::InvalidArgument (" Invalid variable name:" , value_name);
@@ -141,7 +147,7 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
141147 status = reader.Lookup (value_name, &in_tensor);
142148 auto in_data = in_tensor.flat <float >();
143149
144- string quant_name = quant_names[idx] + " -values " ;
150+ string quant_name = quant_names[idx] + suffix ;
145151 if (data_type == TF_DataType::TF_BFLOAT16) {
146152 ConvertToBF16Value (in_tensor, quant_name, writer);
147153 } else if (data_type == TF_DataType::TF_HALF) {
@@ -151,20 +157,36 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
151157 } else {
152158 errors::InvalidArgument (" Unsupported data type:" , data_type);
153159 }
154- for (auto it = ev_suffix.cbegin (); it != ev_suffix.cend (); ++it) {
155- string tensor_name = names[idx] + *it;
156- status = reader.LookupDtypeAndShape (tensor_name, &dtype, &shape);
157- if (status.ok ()) {
158- Tensor tensor (dtype, shape);
159- status = reader.Lookup (tensor_name, &tensor);
160+ if (is_ev) {
161+ for (auto it = ev_suffix.cbegin (); it != ev_suffix.cend (); ++it) {
162+ if (*it == " -values" ) continue ;
163+ string tensor_name = names[idx] + *it;
164+ status = reader.LookupDtypeAndShape (tensor_name, &dtype, &shape);
160165 if (status.ok ()) {
161- writer.Add (quant_names[idx] + *it, tensor);
166+ Tensor tensor (dtype, shape);
167+ status = reader.Lookup (tensor_name, &tensor);
168+ if (status.ok ()) {
169+ writer.Add (quant_names[idx] + *it, tensor);
170+ }
162171 }
163172 }
164173 }
165174 }
166175
167- WriteRestVariables (reader, writer, names, ev_suffix);
176+ if (is_ev) {
177+ WriteRestVariables (reader, writer, names, ev_suffix);
178+ } else {
179+ WriteRestVariables (reader, writer, names);
180+ }
181+ writer.Finish ();
182+ return Status::OK ();
183+ }
184+
185+ Status RemoveVariable (const string& input_prefix, const string& output_prefix,
186+ const std::vector<string>& names) {
187+ BundleReader reader (Env::Default (), input_prefix);
188+ BundleWriter writer (Env::Default (), output_prefix);
189+ WriteRestVariables (reader, writer, names);
168190 writer.Finish ();
169191 return Status::OK ();
170192}
0 commit comments