Skip to content

Commit 9bd933d

Browse files
authored
Improve and fix fake_quantize_op (#13092)
* Improve and fix fake_quantize_op.
1 parent b4d4303 commit 9bd933d

File tree

5 files changed

+369
-376
lines changed

5 files changed

+369
-376
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ function(op_library TARGET)
178178
file(APPEND ${pybind_file} "USE_OP(relu);\n")
179179
elseif(${TARGET} STREQUAL "fake_dequantize")
180180
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
181+
elseif(${TARGET} STREQUAL "fake_quantize")
182+
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
181183
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
182184
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
183185
elseif(${TARGET} STREQUAL "fc")
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
293295
op_library(flatten_op DEPS reshape_op)
294296
op_library(sequence_pad_op DEPS sequence_padding)
295297
op_library(unstack_op DEPS stack_op)
298+
op_library(fake_quantize_op DEPS memory)
296299

297300
if (WITH_GPU)
298301
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/fake_quantize_op.cc

Lines changed: 171 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,86 +14,198 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fake_quantize_op.h"
1616
#include <string>
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/operators/clip_op.h"
19+
#include "paddle/fluid/platform/transform.h"
1720

1821
namespace paddle {
1922
namespace operators {
2023

21-
class FakeQuantizeOp : public framework::OperatorWithKernel {
24+
template <typename T, int MajorType = Eigen::RowMajor,
25+
typename IndexType = Eigen::DenseIndex>
26+
using EigenVectorArrayMap =
27+
Eigen::TensorMap<Eigen::Tensor<T, 1, MajorType, IndexType>>;
28+
29+
template <typename T, int MajorType = Eigen::RowMajor,
30+
typename IndexType = Eigen::DenseIndex>
31+
using ConstEigenVectorArrayMap =
32+
Eigen::TensorMap<const Eigen::Tensor<T, 1, MajorType, IndexType>>;
33+
34+
template <typename T>
35+
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
36+
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
37+
const int num, T* out) {
38+
Eigen::DSizes<Eigen::DenseIndex, 1> idim(num);
39+
Eigen::DSizes<Eigen::DenseIndex, 1> odim(1);
40+
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>> in_e(in, idim);
41+
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> out_e(out, odim);
42+
43+
out_e = in_e.abs().maximum();
44+
}
45+
};
46+
47+
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
48+
49+
template <typename T>
50+
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
51+
void operator()(const platform::CPUDeviceContext& ctx,
52+
const framework::Tensor& in, const framework::Tensor& scale,
53+
const int bin_cnt, framework::Tensor* out) {
54+
T s = scale.data<T>()[0];
55+
platform::Transform<platform::CPUDeviceContext> trans;
56+
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
57+
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
58+
auto in_e = framework::EigenVector<T>::Flatten(in);
59+
auto out_e = framework::EigenVector<T>::Flatten(*out);
60+
61+
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
62+
}
63+
};
64+
65+
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
66+
67+
template <typename T>
68+
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
69+
void operator()(const platform::CPUDeviceContext& ctx,
70+
const framework::Tensor& cur_scale,
71+
const framework::Tensor& last_scale,
72+
const framework::Tensor& iter, const int window_size,
73+
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
74+
T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
75+
int64_t it = iter.data<int64_t>()[0];
76+
int idx = it % window_size;
77+
T removed = scale_arr[idx];
78+
T cur = cur_scale.data<T>()[0];
79+
scale_arr[idx] = cur;
80+
81+
T max = last_scale.data<T>()[0];
82+
if (max < cur) {
83+
max = cur;
84+
} else if (fabs(removed - max) < 1e-6) {
85+
int size = (it > window_size) ? window_size : it;
86+
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
87+
&max);
88+
}
89+
out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
90+
}
91+
};
92+
93+
template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
94+
95+
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
2296
public:
23-
FakeQuantizeOp(const std::string &type,
24-
const framework::VariableNameMap &inputs,
25-
const framework::VariableNameMap &outputs,
26-
const framework::AttributeMap &attrs)
97+
FakeQuantizeAbsMaxOp(const std::string& type,
98+
const framework::VariableNameMap& inputs,
99+
const framework::VariableNameMap& outputs,
100+
const framework::AttributeMap& attrs)
27101
: OperatorWithKernel(type, inputs, outputs, attrs) {}
28102

29-
void InferShape(framework::InferShapeContext *ctx) const override {
103+
void InferShape(framework::InferShapeContext* ctx) const override {
30104
PADDLE_ENFORCE(ctx->HasInput("X"),
31105
"Input(X) of FakeQuantizeOp should not be null.");
32106
PADDLE_ENFORCE(ctx->HasOutput("Out"),
33107
"Output(Out) of FakeQuantizeOp should not be null.");
34-
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
35-
"OutMovingScale(Out) of FakeQuantizeOp should not be null");
36-
// if (ctx->HasInput("InMovingScale")) {
37-
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
38-
//}
39-
// if (ctx->HasInput("InScales")) {
40-
PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
41-
"OutScales(Out) of FakeQuantizeOp should not be null");
42-
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
43-
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
44-
// ctx->Outputs("OutScales")[0],
45-
// "Mean and MeanOut should share the same memory");
46-
//}
108+
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
109+
"Output(Scale) of FakeQuantizeOp should not be null.");
47110
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
111+
ctx->SetOutputDim("OutScale", {1});
48112
ctx->ShareLoD("X", /*->*/ "Out");
49113
}
114+
115+
protected:
116+
framework::OpKernelType GetExpectedKernelType(
117+
const framework::ExecutionContext& ctx) const override {
118+
return framework::OpKernelType(
119+
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
120+
ctx.device_context());
121+
}
50122
};
51123

52-
class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
124+
class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
53125
public:
54126
void Make() override {
55-
AddInput("X", "(Tensor) Input tensor of scale operator.");
56-
AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
57-
.AsDispensable();
58-
AddInput("InMovingScale", "Last scale, used in static quantization.")
59-
.AsDispensable();
60-
AddInput("InCurrentIter",
61-
"Last iteration number, used in static quantization.")
62-
.AsDispensable();
63-
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
64-
AddOutput("OutScales",
65-
"(Tensor) scale buffer, used in static quantization.")
66-
.AsDispensable();
67-
AddOutput("OutMovingScale", " Current scale");
68-
AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
69-
AddAttr<std::string>("quantize_type",
70-
"(string, default abs_max)"
71-
"The scaling tpe of the quantize operator.")
72-
.SetDefault("abs_max");
73-
AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
127+
AddInput("X", "(Tensor) Input is float data type.");
128+
AddOutput("Out",
129+
"(Tensor) Output of quantized low level tensor, "
130+
"but also saved as float data type.");
131+
AddOutput("OutScale", "(Tensor) Current scale");
74132
AddAttr<int>("bit_length", "(int, default 8)")
75133
.SetDefault(8)
76-
.AddCustomChecker([](const int &bit_length) {
134+
.AddCustomChecker([](const int& bit_length) {
77135
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
78136
"'bit_length' should be between 1 and 16.");
79137
});
80-
AddAttr<bool>("is_test", "").SetDefault(false);
81138
AddComment(R"DOC(
82139
FakeQuantize operator
83140
84-
quantize_type = abs_max:
141+
$$scale = max(abs(X))$$
142+
$$range = 2^{bit_length - 1} - 1$$
143+
$$Out = round(X/scale * range)$$
85144
86-
$$scale = max(abs(x))$$
145+
)DOC");
146+
}
147+
};
87148

88-
quantize_type = range_abs_max:
149+
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
150+
public:
151+
FakeQuantizeRangeAbsMaxOp(const std::string& type,
152+
const framework::VariableNameMap& inputs,
153+
const framework::VariableNameMap& outputs,
154+
const framework::AttributeMap& attrs)
155+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
89156

90-
$$scale = max(max(abs(x)), history_abs_max)$$
157+
void InferShape(framework::InferShapeContext* ctx) const override {
158+
PADDLE_ENFORCE(ctx->HasInput("X"),
159+
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
160+
PADDLE_ENFORCE(
161+
ctx->HasOutput("Out"),
162+
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
163+
PADDLE_ENFORCE(
164+
ctx->HasOutput("OutScale"),
165+
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
166+
if (ctx->HasOutput("OutScales")) {
167+
int window_size = ctx->Attrs().Get<int>("window_size");
168+
ctx->SetOutputDim("OutScales", {window_size});
169+
}
170+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
171+
ctx->SetOutputDim("OutScale", {1});
172+
ctx->ShareLoD("X", /*->*/ "Out");
173+
}
91174

92-
quantize_type = moving_average_abs_max:
175+
protected:
176+
framework::OpKernelType GetExpectedKernelType(
177+
const framework::ExecutionContext& ctx) const override {
178+
return framework::OpKernelType(
179+
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
180+
ctx.device_context());
181+
}
182+
};
93183

94-
$$scale = 0.1*scale+0.9*new_abs_max)$$
184+
class FakeQuantizeRangeAbsMaxOpMaker
185+
: public framework::OpProtoAndCheckerMaker {
186+
public:
187+
void Make() override {
188+
AddInput("X", "(Tensor) Input is float data type.");
189+
AddInput("InScale", "Last scale.");
190+
AddInput("Iter", "Global step iteration.").AsDispensable();
191+
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
192+
AddOutput("OutScale", " Current scale");
193+
AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
194+
AddAttr<int>("window_size", "(int, default 10000) window range size.")
195+
.SetDefault(10000);
196+
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
197+
.SetDefault(8)
198+
.AddCustomChecker([](const int& bit_length) {
199+
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
200+
"'bit_length' should be between 1 and 16.");
201+
});
202+
AddAttr<bool>("is_test", "").SetDefault(false);
203+
AddComment(R"DOC(
204+
FakeQuantize operator is used in static quantization.
95205
96-
$$Out = scale*X$$
206+
$$scale = max(max(abs(x)), history_abs_max)$$
207+
$$range = 2^{bit_length - 1} - 1$$
208+
$$Out = round(X/scale * range)$$
97209
98210
)DOC");
99211
}
@@ -103,10 +215,16 @@ quantize_type = moving_average_abs_max:
103215
} // namespace paddle
104216

105217
namespace ops = paddle::operators;
218+
using CPU = paddle::platform::CPUDeviceContext;
219+
220+
REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
221+
ops::FakeQuantizeAbsMaxOpMaker,
222+
paddle::framework::EmptyGradOpMaker);
223+
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
224+
ops::FakeQuantizeAbsMaxKernel<CPU, float>);
106225

107-
REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
226+
REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
227+
ops::FakeQuantizeRangeAbsMaxOpMaker,
108228
paddle::framework::EmptyGradOpMaker);
109-
REGISTER_OP_CPU_KERNEL(
110-
fake_quantize,
111-
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
112-
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
229+
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
230+
ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);

0 commit comments

Comments
 (0)