Skip to content

Commit 0353edd

Browse files
authored
Improve fake_dequantize_op. (#12877)
* Improve fake_dequantize_op. * Follow comments.
1 parent 11e01d9 commit 0353edd

File tree

4 files changed

+97
-32
lines changed

4 files changed

+97
-32
lines changed

paddle/fluid/operators/fake_dequantize_op.cc

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,32 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21+
template <typename T>
22+
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
23+
void operator()(const platform::CPUDeviceContext& dev_ctx,
24+
const framework::Tensor* in, const framework::Tensor* scale,
25+
T max_range, framework::Tensor* out) {
26+
auto in_e = framework::EigenVector<T>::Flatten(*in);
27+
const T* scale_factor = scale->data<T>();
28+
auto out_e = framework::EigenVector<T>::Flatten(*out);
29+
30+
auto& dev = *dev_ctx.eigen_device();
31+
out_e.device(dev) = (scale_factor[0] / max_range) * in_e;
32+
}
33+
};
34+
35+
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
36+
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
37+
2138
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
2239
public:
23-
FakeDequantizeMaxAbsOp(const std::string &type,
24-
const framework::VariableNameMap &inputs,
25-
const framework::VariableNameMap &outputs,
26-
const framework::AttributeMap &attrs)
40+
FakeDequantizeMaxAbsOp(const std::string& type,
41+
const framework::VariableNameMap& inputs,
42+
const framework::VariableNameMap& outputs,
43+
const framework::AttributeMap& attrs)
2744
: OperatorWithKernel(type, inputs, outputs, attrs) {}
2845

29-
void InferShape(framework::InferShapeContext *ctx) const override {
46+
void InferShape(framework::InferShapeContext* ctx) const override {
3047
PADDLE_ENFORCE(ctx->HasInput("X"),
3148
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
3249
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
4259
AddInput("X",
4360
"(Tensor) The input with float-32/64 type is the "
4461
"low precision tensor.");
62+
AddInput("Scale", "(float) The scale in quantization stage.");
4563
AddOutput("Out",
4664
"(Tensor) The output is the dequantized high "
4765
"precision tensor.");
48-
AddAttr<int>("num_bits",
49-
"(int) `num_bits` is the quantization level bits, "
50-
"such as 2, 5, 8.");
51-
AddAttr<float>("scale",
52-
"(float) The maximum absolute value of low precision tensor."
53-
"It is usually calculated by the fake_quantize_max_abs_op.");
66+
AddAttr<float>("max_range", "(float) The max range in quantization stage.");
5467
AddComment(R"DOC(
5568
FakeDequantizeMaxAbsOp operator.
5669
5770
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
5871
59-
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
72+
$$Out = \frac{scale*X}{ max_range }$$
6073
6174
)DOC");
6275
}

paddle/fluid/operators/fake_dequantize_op.cu

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,42 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fake_dequantize_op.h"
1616

17+
namespace paddle {
18+
namespace operators {
19+
20+
template <typename T>
21+
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num,
22+
T* out) {
23+
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
24+
if (idx < num) {
25+
out[idx] = in[idx] * scale[0] / max_range;
26+
}
27+
}
28+
29+
template <typename T>
30+
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
31+
void operator()(const platform::CUDADeviceContext& dev_ctx,
32+
const framework::Tensor* in, const framework::Tensor* scale,
33+
T max_range, framework::Tensor* out) {
34+
const T* in_data = in->data<T>();
35+
const T* scale_factor = scale->data<T>();
36+
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
37+
38+
int num = in->numel();
39+
int block = 512;
40+
int grid = (num + block - 1) / block;
41+
42+
KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
43+
in_data, scale_factor, max_range, num, out_data);
44+
}
45+
};
46+
47+
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
48+
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
49+
50+
} // namespace operators
51+
} // namespace paddle
52+
1753
namespace ops = paddle::operators;
1854
using CUDA = paddle::platform::CUDADeviceContext;
1955
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,

paddle/fluid/operators/fake_dequantize_op.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,29 @@ limitations under the License. */
1919

2020
namespace paddle {
2121
namespace operators {
22+
23+
template <typename DeviceContext, typename T>
24+
struct DequantizeFunctor {
25+
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
26+
const framework::Tensor* scale, T max_range,
27+
framework::Tensor* out);
28+
};
29+
2230
template <typename DeviceContext, typename T>
2331
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
2432
public:
2533
virtual void Compute(const framework::ExecutionContext& ctx) const {
2634
auto* in = ctx.Input<framework::Tensor>("X");
35+
auto* scale = ctx.Input<framework::Tensor>("Scale");
2736
auto* out = ctx.Output<framework::Tensor>("Out");
28-
out->mutable_data<T>(in->place());
2937

30-
int num_bits = ctx.Attr<int>("num_bits");
31-
T scale = static_cast<T>(ctx.Attr<float>("scale"));
32-
int range = std::pow(2, num_bits) - 1;
38+
float max_range = ctx.Attr<float>("max_range");
39+
40+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
41+
out->mutable_data<T>(dev_ctx.GetPlace());
3342

34-
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
35-
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
36-
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
37-
eigen_out.device(dev) = (scale / range) * eigen_in;
43+
DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale,
44+
static_cast<T>(max_range), out);
3845
}
3946
};
4047

python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,50 @@
2020
from op_test import OpTest
2121

2222

23-
def quantize_max_abs(x, num_bits):
24-
range = math.pow(2, num_bits) - 1
23+
def quantize_max_abs(x, max_range):
2524
scale = np.max(np.abs(x).flatten())
26-
y = np.round(x / scale * range)
25+
y = np.round(x / scale * max_range)
2726
return y, scale
2827

2928

30-
def dequantize_max_abs(x, num_bits, scale):
31-
range = math.pow(2, num_bits) - 1
32-
y = (scale / range) * x
29+
def dequantize_max_abs(x, scale, max_range):
30+
y = (scale / max_range) * x
3331
return y
3432

3533

3634
class TestFakeDequantizeMaxAbsOp(OpTest):
3735
def set_args(self):
3836
self.num_bits = 8
37+
self.max_range = math.pow(2, self.num_bits - 1) - 1
38+
self.data_type = "float32"
3939

4040
def setUp(self):
4141
self.set_args()
4242
self.op_type = "fake_dequantize_max_abs"
43-
x = np.random.randn(31, 65).astype("float32")
44-
yq, scale = quantize_max_abs(x, self.num_bits)
45-
ydq = dequantize_max_abs(yq, self.num_bits, scale)
43+
x = np.random.randn(31, 65).astype(self.data_type)
44+
yq, scale = quantize_max_abs(x, self.max_range)
45+
ydq = dequantize_max_abs(yq, scale, self.max_range)
4646

47-
self.inputs = {'X': yq}
48-
self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
47+
self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
48+
self.attrs = {'max_range': self.max_range}
4949
self.outputs = {'Out': ydq}
5050

5151
def test_check_output(self):
5252
self.check_output()
5353

5454

55-
class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
55+
class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
56+
def set_args(self):
57+
self.num_bits = 8
58+
self.max_range = math.pow(2, self.num_bits - 1) - 1
59+
self.data_type = "float64"
60+
61+
62+
class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
5663
def set_args(self):
5764
self.num_bits = 5
65+
self.max_range = math.pow(2, self.num_bits - 1) - 1
66+
self.data_type = "float32"
5867

5968

6069
if __name__ == "__main__":

0 commit comments

Comments
 (0)