Skip to content

Commit 3537966

Browse files
authored
[Quantization] Add checkpoint quantization in low precision optimization tools. (#449)
1 parent 46749b8 commit 3537966

File tree

7 files changed

+234
-65
lines changed

7 files changed

+234
-65
lines changed

tensorflow/c/quantize_embedding_variable.cc

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,8 @@ namespace tensorflow {
2525
namespace checkpoint {
2626

2727
void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
28-
const std::vector<string>& names,
29-
const std::set<string>& ev_suffix) {
30-
std::set<string> updated_names;
31-
for (int idx = 0; idx < names.size(); ++idx) {
32-
updated_names.insert(names[idx] + "-values");
33-
for (auto it = ev_suffix.cbegin(); it != ev_suffix.cend(); ++it) {
34-
updated_names.insert(names[idx] + *it);
35-
}
36-
}
37-
28+
const std::vector<string>& ignored_names) {
29+
std::set<string> excluded_names(ignored_names.cbegin(), ignored_names.cend());
3830
std::vector<std::string> tensor_names;
3931
reader.Seek(kHeaderEntryKey);
4032
reader.Next();
@@ -45,7 +37,7 @@ void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
4537
Status status;
4638
DataType dtype;
4739
TensorShape shape;
48-
if (updated_names.count(tensor_name)) continue;
40+
if (excluded_names.count(tensor_name)) continue;
4941
status = reader.LookupDtypeAndShape(tensor_name, &dtype, &shape);
5042
if (status.ok()) {
5143
Tensor tensor(dtype, shape);
@@ -55,6 +47,18 @@ void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
5547
}
5648
}
5749

50+
void WriteRestVariables(BundleReader& reader, BundleWriter& writer,
51+
const std::vector<string>& ignored_names,
52+
const std::set<string>& ev_suffix) {
53+
std::vector<string> ev_names;
54+
for (int idx = 0; idx < ignored_names.size(); ++idx) {
55+
for (auto it = ev_suffix.cbegin(); it != ev_suffix.cend(); ++it) {
56+
ev_names.push_back(ignored_names[idx] + *it);
57+
}
58+
}
59+
WriteRestVariables(reader, writer, ev_names);
60+
}
61+
5862
void ConvertToBF16Value(const Tensor& in_tensor, const string name,
5963
BundleWriter& writer) {
6064
auto in_data = in_tensor.flat<float>();
@@ -120,19 +124,21 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
120124
const std::vector<string>& names,
121125
const std::vector<string>& quant_names,
122126
const std::vector<string>& scale_names,
123-
TF_DataType data_type) {
127+
const TF_DataType data_type,
128+
const bool is_ev) {
124129
BundleReader reader(Env::Default(), input_prefix);
125130
BundleWriter writer(Env::Default(), output_prefix);
126131
const std::set<string> ev_suffix = {
127132
"-freqs", "-freqs_filtered", "-keys",
128133
"-keys_filtered", "-partition_filter_offset", "-partition_offset",
129-
"-versions", "-versions_filtered"};
134+
"-versions", "-versions_filtered", "-values"};
130135

131136
for (int idx = 0; idx < names.size(); ++idx) {
132137
Status status;
133138
DataType dtype;
134139
TensorShape shape;
135-
string value_name = names[idx] + "-values";
140+
string suffix = is_ev ? "-values" : "";
141+
string value_name = names[idx] + suffix;
136142
status = reader.LookupDtypeAndShape(value_name, &dtype, &shape);
137143
if (!status.ok()) {
138144
errors::InvalidArgument("Invalid variable name:", value_name);
@@ -141,7 +147,7 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
141147
status = reader.Lookup(value_name, &in_tensor);
142148
auto in_data = in_tensor.flat<float>();
143149

144-
string quant_name = quant_names[idx] + "-values";
150+
string quant_name = quant_names[idx] + suffix;
145151
if (data_type == TF_DataType::TF_BFLOAT16) {
146152
ConvertToBF16Value(in_tensor, quant_name, writer);
147153
} else if (data_type == TF_DataType::TF_HALF) {
@@ -151,20 +157,36 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
151157
} else {
152158
errors::InvalidArgument("Unsupported data type:", data_type);
153159
}
154-
for (auto it = ev_suffix.cbegin(); it != ev_suffix.cend(); ++it) {
155-
string tensor_name = names[idx] + *it;
156-
status = reader.LookupDtypeAndShape(tensor_name, &dtype, &shape);
157-
if (status.ok()) {
158-
Tensor tensor(dtype, shape);
159-
status = reader.Lookup(tensor_name, &tensor);
160+
if (is_ev) {
161+
for (auto it = ev_suffix.cbegin(); it != ev_suffix.cend(); ++it) {
162+
if (*it == "-values") continue;
163+
string tensor_name = names[idx] + *it;
164+
status = reader.LookupDtypeAndShape(tensor_name, &dtype, &shape);
160165
if (status.ok()) {
161-
writer.Add(quant_names[idx] + *it, tensor);
166+
Tensor tensor(dtype, shape);
167+
status = reader.Lookup(tensor_name, &tensor);
168+
if (status.ok()) {
169+
writer.Add(quant_names[idx] + *it, tensor);
170+
}
162171
}
163172
}
164173
}
165174
}
166175

167-
WriteRestVariables(reader, writer, names, ev_suffix);
176+
if (is_ev) {
177+
WriteRestVariables(reader, writer, names, ev_suffix);
178+
} else {
179+
WriteRestVariables(reader, writer, names);
180+
}
181+
writer.Finish();
182+
return Status::OK();
183+
}
184+
185+
Status RemoveVariable(const string& input_prefix, const string& output_prefix,
186+
const std::vector<string>& names) {
187+
BundleReader reader(Env::Default(), input_prefix);
188+
BundleWriter writer(Env::Default(), output_prefix);
189+
WriteRestVariables(reader, writer, names);
168190
writer.Finish();
169191
return Status::OK();
170192
}

tensorflow/c/quantize_embedding_variable.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ Status QuantizeEmbeddingVariable(const string& input_prefix,
3434
const std::vector<string>& names,
3535
const std::vector<string>& quant_names,
3636
const std::vector<string>& scale_names,
37-
TF_DataType data_type);
37+
const TF_DataType data_type, const bool is_ev);
38+
39+
Status RemoveVariable(const string& input_prefix, const string& output_prefix,
40+
const std::vector<string>& names);
3841

3942
} // namespace checkpoint
4043
} // namespace tensorflow

tensorflow/python/util/quantize_embedding_variable.i

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,41 @@ limitations under the License.
2626
%unignore tensorflow;
2727
%unignore tensorflow::checkpoint;
2828
%unignore QuantizeEmbeddingVariablesByName;
29+
%unignore RemoveVariablesByName;
2930

3031
%{
3132
void QuantizeEmbeddingVariablesByName(string input_prefix, string output_prefix,
3233
string names_string,
3334
string quant_names_string,
3435
string scale_names_string,
35-
TF_DataType data_type) {
36+
TF_DataType data_type, bool is_ev) {
3637
std::vector<string> names = tensorflow::str_util::Split(names_string, ',');
3738
std::vector<string> quant_names =
3839
tensorflow::str_util::Split(quant_names_string, ',');
3940
std::vector<string> scale_names =
4041
tensorflow::str_util::Split(scale_names_string, ',');
4142

4243
tensorflow::checkpoint::QuantizeEmbeddingVariable(
43-
input_prefix, output_prefix, names, quant_names, scale_names, data_type);
44+
input_prefix, output_prefix, names, quant_names, scale_names, data_type,
45+
is_ev);
4446
}
4547
%}
4648

4749
void QuantizeEmbeddingVariablesByName(string input_prefix, string output_prefix,
4850
string names_string,
4951
string quant_names_string,
5052
string scale_names_string,
51-
TF_DataType data_type);
53+
TF_DataType data_type, bool is_ev);
54+
55+
%{
56+
void RemoveVariablesByName(string input_prefix, string output_prefix,
57+
string names_string) {
58+
std::vector<string> names = tensorflow::str_util::Split(names_string, ',');
59+
tensorflow::checkpoint::RemoveVariable(input_prefix, output_prefix, names);
60+
}
61+
%}
62+
63+
void RemoveVariablesByName(string input_prefix, string output_prefix,
64+
string names_string);
5265

5366
%unignoreall

tensorflow/python/util/quantize_embedding_variable.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
"""Exposes the Python wrapper for quantize embedding variable."""
1616
from __future__ import absolute_import, division, print_function
1717

18-
from tensorflow.python.pywrap_tensorflow import QuantizeEmbeddingVariablesByName
18+
from tensorflow.python.pywrap_tensorflow import (
19+
QuantizeEmbeddingVariablesByName,
20+
RemoveVariablesByName,
21+
)
1922
from tensorflow.python.util import compat
2023

2124

2225
def quantize_by_name(
23-
input_prefix, output_prefix, names, quant_names, scale_names, dtype
26+
input_prefix, output_prefix, names, quant_names, scale_names, dtype, is_ev
2427
):
2528
"""Python wrapper for quantize embedding variable.
2629
@@ -31,6 +34,7 @@ def quantize_by_name(
3134
quant_names: List of quantized tensor names.
3235
scale_names: List of scale tensor names.
3336
dtype: tf.bfloat16 or tf.int8
37+
is_ev: Boolean. Whether variables are EmbeddingVariable.
3438
"""
3539
input_prefix = compat.as_bytes(input_prefix)
3640
output_prefix = compat.as_bytes(output_prefix)
@@ -44,4 +48,19 @@ def quantize_by_name(
4448
quant_names_string,
4549
scale_names_string,
4650
dtype.as_datatype_enum,
51+
is_ev,
4752
)
53+
54+
55+
def remove_variables_by_name(input_prefix, output_prefix, names):
56+
"""Python wrapper for remove variables.
57+
58+
Args:
59+
input_prefix: String. Prefix of input checkpoint.
60+
output_prefix: String. Prefix of output checkpoint.
61+
names: List of tensor names to be removed.
62+
"""
63+
input_prefix = compat.as_bytes(input_prefix)
64+
output_prefix = compat.as_bytes(output_prefix)
65+
names_string = compat.as_bytes(",".join(names))
66+
RemoveVariablesByName(input_prefix, output_prefix, names_string)

tools/low_precision_optimize/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,17 @@ for i in range(10):
6060
with open('calib_data.npy', 'wb') as f:
6161
np.save(f, calib_data)
6262
```
63+
64+
## 转换模型参数
65+
此外,鉴于部分使用场景中存在仅更新模型参数的需求,该工具提供单独量化模型参数的功能,使用方式如下:
66+
```python
67+
from low_precision_optimize import convert_ckpt
68+
69+
# 指定输入的待优化参数checkpoint
70+
ckpt_prefix = 'dlrm/new_variables/variables'
71+
# 指定输出的优化参数checkpoint
72+
save_prefix = 'dlrm/opt_variables/variables'
73+
# 指定前一环节中优化后的saved_model目录
74+
opt_model_path = 'dlrm/saved_model_opt'
75+
convert_ckpt(ckpt_prefix, save_prefix, opt_model_path)
76+
```

0 commit comments

Comments
 (0)