Skip to content

Commit 3a29821

Browse files
authored
Develop a fake dequantized op for fixed-point quantization training framework. (#10965)
* Develop a fake dequantized op for fixed-point quantization training framework. * Add the missing file.
1 parent 66ec827 commit 3a29821

File tree

5 files changed

+201
-0
lines changed

5 files changed

+201
-0
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ function(op_library TARGET)
168168
file(APPEND ${pybind_file} "USE_OP(relu);\n")
169169
elseif(${TARGET} STREQUAL "reduce")
170170
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
171+
elseif(${TARGET} STREQUAL "fake_dequantize")
172+
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
171173
else()
172174
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
173175
endif()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/fake_dequantize_op.h"
16+
#include <string>
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
22+
public:
23+
FakeDequantizeMaxAbsOp(const std::string &type,
24+
const framework::VariableNameMap &inputs,
25+
const framework::VariableNameMap &outputs,
26+
const framework::AttributeMap &attrs)
27+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
28+
29+
void InferShape(framework::InferShapeContext *ctx) const override {
30+
PADDLE_ENFORCE(ctx->HasInput("X"),
31+
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
32+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
33+
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
34+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
35+
ctx->ShareLoD("X", /*->*/ "Out");
36+
}
37+
};
38+
39+
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
40+
public:
41+
void Make() override {
42+
AddInput("X",
43+
"(Tensor) The input with float-32/64 type is the "
44+
"low precision tensor.");
45+
AddOutput("Out",
46+
"(Tensor) The output is the dequantized high "
47+
"precision tensor.");
48+
AddAttr<int>("num_bits",
49+
"(int) `num_bits` is the quantization level bits, "
50+
"such as 2, 5, 8.");
51+
AddAttr<float>("scale",
52+
"(float) The maximum absolute value of low precision tensor."
53+
"It is usually calculated by the fake_quantize_max_abs_op.");
54+
AddComment(R"DOC(
55+
FakeDequantizeMaxAbsOp operator.
56+
57+
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
58+
59+
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
60+
61+
)DOC");
62+
}
63+
};
64+
65+
} // namespace operators
66+
} // namespace paddle
67+
68+
namespace ops = paddle::operators;
69+
using CPU = paddle::platform::CPUDeviceContext;
70+
71+
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
72+
ops::FakeDequantizeMaxAbsOpMaker,
73+
paddle::framework::EmptyGradOpMaker);
74+
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
75+
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
76+
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/fake_dequantize_op.h"
16+
17+
namespace ops = paddle::operators;
18+
using CUDA = paddle::platform::CUDADeviceContext;
19+
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
20+
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
21+
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
template <typename DeviceContext, typename T>
23+
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
24+
public:
25+
virtual void Compute(const framework::ExecutionContext& ctx) const {
26+
auto* in = ctx.Input<framework::Tensor>("X");
27+
auto* out = ctx.Output<framework::Tensor>("Out");
28+
out->mutable_data<T>(in->place());
29+
30+
int num_bits = ctx.Attr<int>("num_bits");
31+
T scale = static_cast<T>(ctx.Attr<float>("scale"));
32+
int range = std::pow(2, num_bits) - 1;
33+
34+
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
35+
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
36+
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
37+
eigen_out.device(dev) = (scale / range) * eigen_in;
38+
}
39+
};
40+
41+
} // namespace operators
42+
} // namespace paddle
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
import math
18+
from op_test import OpTest
19+
20+
21+
def quantize_max_abs(x, num_bits):
22+
range = math.pow(2, num_bits) - 1
23+
scale = np.max(np.abs(x).flatten())
24+
y = np.round(x / scale * range)
25+
return y, scale
26+
27+
28+
def dequantize_max_abs(x, num_bits, scale):
29+
range = math.pow(2, num_bits) - 1
30+
y = (scale / range) * x
31+
return y
32+
33+
34+
class TestFakeDequantizeMaxAbsOp(OpTest):
35+
def set_args(self):
36+
self.num_bits = 8
37+
38+
def setUp(self):
39+
self.set_args()
40+
self.op_type = "fake_dequantize_max_abs"
41+
x = np.random.randn(31, 65).astype("float32")
42+
yq, scale = quantize_max_abs(x, self.num_bits)
43+
print 'scale ', scale
44+
ydq = dequantize_max_abs(yq, self.num_bits, scale)
45+
46+
self.inputs = {'X': yq}
47+
self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
48+
self.outputs = {'Out': ydq}
49+
50+
def test_check_output(self):
51+
self.check_output()
52+
53+
54+
class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
55+
def set_args(self):
56+
self.num_bits = 5
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)