
Commit ec11135

Merge pull request #16341 from wzzju/add_channel_wise_in_quant_pass

Add channel wise in quant pass.

2 parents e235882 + 8965819, commit ec11135
10 files changed: +655 −146 lines changed

paddle/fluid/operators/fake_dequantize_op.cc
43 additions, 0 deletions

```diff
@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
+template <typename T>
+struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& dev_ctx,
+                  const framework::Tensor* in, const framework::Tensor** scales,
+                  const int scale_num, T max_range, framework::Tensor* out) {
+    if (scale_num == 1) {
+      const int channel = in->dims()[0];
+      const T* scale_factor = scales[0]->data<T>();
+      for (int i = 0; i < channel; i++) {
+        T s = scale_factor[i];
+        framework::Tensor one_channel_in = in->Slice(i, i + 1);
+        framework::Tensor one_channel_out = out->Slice(i, i + 1);
+        auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
+        auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
+        auto& dev = *dev_ctx.eigen_device();
+        out_e.device(dev) = (s / max_range) * in_e;
+      }
+    } else if (scale_num == 2) {
+      int batch_size = in->dims()[0];
+      int channel = in->dims()[1];
+      const T* scale_one = scales[0]->data<T>();
+      const T* scale_two = scales[1]->data<T>();
+      for (int i = 0; i < batch_size; i++) {
+        framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
+            framework::slice_ddim(in->dims(), 1, in->dims().size()));
+        framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
+            framework::slice_ddim(out->dims(), 1, out->dims().size()));
+        for (int j = 0; j < channel; j++) {
+          T s = scale_one[j];
+          framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
+          framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
+          auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
+          auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
+          auto& dev = *dev_ctx.eigen_device();
+          out_e.device(dev) = (s * scale_two[0] / max_range) * in_e;
+        }
+      }
+    }
+  }
+};
+
 template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
 template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
+template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
+template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
 
 class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
  public:
```
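The CPU functor above reduces to per-channel scaling: every element of channel i is multiplied by scale[i] / max_range (times scale_two[0] in the two-scale case). Below is a minimal standalone sketch of that arithmetic, assuming channels are contiguous along the first dimension; plain C++ with no Paddle or Eigen dependencies, and dequantize_per_channel is an illustrative name, not part of this PR.

```cpp
// Minimal sketch of the per-channel dequantization arithmetic
// (plain C++, no Paddle dependencies; names are illustrative).
#include <cstdio>
#include <vector>

// One scale per output channel: out = in * scales[c] / max_range.
void dequantize_per_channel(const std::vector<float>& in, int channel,
                            const std::vector<float>& scales, float max_range,
                            std::vector<float>* out) {
  const int channel_size = static_cast<int>(in.size()) / channel;
  for (int c = 0; c < channel; ++c) {
    for (int i = 0; i < channel_size; ++i) {
      const int idx = c * channel_size + i;
      (*out)[idx] = in[idx] * scales[c] / max_range;
    }
  }
}

int main() {
  // Two channels of two elements each, quantized with bit_length = 8,
  // so max_range = 2^(8-1) - 1 = 127.
  std::vector<float> in = {127.f, -64.f, 127.f, 32.f};
  std::vector<float> scales = {0.5f, 2.0f};
  std::vector<float> out(in.size());
  dequantize_per_channel(in, /*channel=*/2, scales, /*max_range=*/127.f, &out);
  for (float v : out) printf("%f\n", v);  // 0.5 -0.251969 2.0 0.503937
  return 0;
}
```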

paddle/fluid/operators/fake_dequantize_op.cu
58 additions, 0 deletions

```diff
@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
   }
 };
 
+template <typename T>
+__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
+                                   int num, int channel, T* out) {
+  int tid = threadIdx.x;
+  int channel_size = num / channel;
+  const T* in_c = in + blockIdx.x * channel_size;
+  T* out_c = out + blockIdx.x * channel_size;
+  for (int i = tid; i < channel_size; i += blockDim.x) {
+    out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
+  }
+}
+
+template <typename T>
+__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
+                                   const T* scale_two, T max_range, int num,
+                                   int batch_size, int channel, T* out) {
+  int tid = threadIdx.x;
+  int channel_size = num / (batch_size * channel);
+  int scale_index = blockIdx.x % channel;
+  const T* in_c = in + blockIdx.x * channel_size;
+  T* out_c = out + blockIdx.x * channel_size;
+  for (int i = tid; i < channel_size; i += blockDim.x) {
+    out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
+  }
+}
+
+template <typename T>
+struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& dev_ctx,
+                  const framework::Tensor* in, const framework::Tensor** scales,
+                  const int scale_num, T max_range, framework::Tensor* out) {
+    const T* in_data = in->data<T>();
+    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
+    if (scale_num == 1) {
+      int num = in->numel();
+      int channel = in->dims()[0];
+      const T* scale_factor = scales[0]->data<T>();
+      int block = 1024;
+      int grid = channel;
+      DequantizeOneScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          in_data, scale_factor, max_range, num, channel, out_data);
+    } else if (scale_num == 2) {
+      int num = in->numel();
+      int batch_size = in->dims()[0];
+      int channel = in->dims()[1];
+      const T* scale_one = scales[0]->data<T>();
+      const T* scale_two = scales[1]->data<T>();
+      int block = 1024;
+      int grid = batch_size * channel;
+      DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          in_data, scale_one, scale_two, max_range, num, batch_size, channel,
+          out_data);
+    }
+  }
+};
+
 template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
 template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
+template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
+template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace operators
 }  // namespace paddle
```
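For reference, the CUDA path launches one block per contiguous channel slice (grid = channel for one scale, batch_size * channel for two), and each block's threads stride across that slice's channel_size elements. Below is a host-side sketch of the DequantizeTwoScale indexing under the same contiguous-layout assumption; plain C++, illustrative only, not part of the PR.

```cpp
// Host-side sketch of the DequantizeTwoScale indexing (plain C++,
// illustrative only). Each CUDA block b handles one contiguous channel
// slice and reuses the per-channel scale across the batch dimension.
#include <cstdio>

int main() {
  const int batch_size = 2, channel = 3;
  const int num = 24;  // total number of elements
  const int channel_size = num / (batch_size * channel);  // as in the kernel
  for (int b = 0; b < batch_size * channel; ++b) {  // b plays blockIdx.x
    const int scale_index = b % channel;  // per-channel scale, shared per batch
    const int begin = b * channel_size;
    printf("block %d: out[%d..%d) = in * scale_one[%d] * scale_two[0] / max_range\n",
           b, begin, begin + channel_size, scale_index);
  }
  return 0;
}
```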

paddle/fluid/operators/fake_dequantize_op.h
27 additions, 18 deletions

```diff
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+#include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
@@ -28,6 +29,13 @@ struct DequantizeFunctor {
   framework::Tensor* out);
 };
 
+template <typename DeviceContext, typename T>
+struct ChannelDequantizeFunctor {
+  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
+                  const framework::Tensor** scales, const int scale_num,
+                  T max_range, framework::Tensor* out);
+};
+
 template <typename DeviceContext, typename T>
 class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
  public:
@@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     auto scales = ctx.MultiInput<framework::Tensor>("Scales");
     auto* out = ctx.Output<framework::Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
-                      "The number of first scale values must be the same with "
-                      "first dimension value of Input(X).");
-
     auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
-    int max_range = std::pow(2, quant_bits[0] - 1) - 1;
+    int max_range = 1;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     out->mutable_data<T>(dev_ctx.GetPlace());
-
-    auto dequant = DequantizeFunctor<DeviceContext, T>();
-    for (int64_t i = 0; i < in->dims()[0]; i++) {
-      framework::Tensor one_channel_in = in->Slice(i, i + 1);
-      framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
-      dequant(dev_ctx, &one_channel_in, &one_channel_scale,
-              static_cast<T>(max_range), &one_channel_out);
-    }
-
-    if (scales.size() == 2) {
+    int scale_num = scales.size();
+    if (scale_num == 1) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[0],
+          "The number of first scale values must be the same with "
+          "first dimension value of Input(X) when the `Scales` has only one "
+          "element.");
+      max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
+    } else if (scale_num == 2) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[1],
+          "The number of first scale values must be the same with "
+          "second dimension value of Input(X) when the `Scales` has two "
+          "elements.");
       PADDLE_ENFORCE_EQ(
           scales[1]->numel(), 1,
           "The second scale tensor should only have one value at now.");
-      max_range = std::pow(2, quant_bits[1] - 1) - 1;
-      dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out);
+      max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
+                   (std::pow(2, quant_bits[1] - 1) - 1);
     }
+    ChannelDequantizeFunctor<DeviceContext, T>()(
+        dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range), out);
   }
 };
```
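The kernel change concentrates on max_range bookkeeping: each quantization stage contributes a factor of 2^(bits-1) - 1, and with two scales the stages multiply (e.g. quant_bits = {8, 8} gives 127 * 127 = 16129), since the data was quantized twice. A standalone sketch of that computation follows; plain C++, and compute_max_range is an illustrative name, not part of the operator.

```cpp
// Sketch of the max_range bookkeeping in FakeChannelWiseDequantizeMaxAbsKernel:
// each quantization stage contributes 2^(bits-1) - 1, and stages multiply.
// Plain C++, illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

int compute_max_range(const std::vector<int>& quant_bits, int scale_num) {
  int max_range = 1;
  if (scale_num == 1) {
    max_range *= static_cast<int>(std::pow(2, quant_bits[0] - 1)) - 1;
  } else if (scale_num == 2) {
    max_range *= (static_cast<int>(std::pow(2, quant_bits[0] - 1)) - 1) *
                 (static_cast<int>(std::pow(2, quant_bits[1] - 1)) - 1);
  }
  return max_range;
}

int main() {
  printf("%d\n", compute_max_range({8}, 1));     // 127
  printf("%d\n", compute_max_range({8, 8}, 2));  // 127 * 127 = 16129
  return 0;
}
```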

paddle/fluid/operators/fake_quantize_op.cc
49 additions, 4 deletions

```diff
@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
 
 template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
 
+template <typename T>
+struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
+                  const int num, const int channel, T* out) {
+    const int channel_size = num / channel;
+    for (int i = 0; i < channel; i++) {
+      auto* start = in + i * channel_size;
+      auto* end = in + (i + 1) * channel_size;
+      out[i] = std::abs(*(std::max_element(start, end, Compare<T>())));
+    }
+  }
+};
+
+template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
+
 template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
 
 template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
 
+template <typename T>
+struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, const int channel,
+                  framework::Tensor* out) {
+    auto* scale_data = scale.data<T>();
+    auto* in_data = in.data<T>();
+    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
+    const int channel_size = in.numel() / channel;
+    platform::Transform<platform::CPUDeviceContext> trans;
+    for (int i = 0; i < channel; i++) {
+      T s = scale_data[i];
+      auto* start = in_data + i * channel_size;
+      auto* end = in_data + (i + 1) * channel_size;
+      trans(ctx, start, end, out_data + i * channel_size,
+            ClipFunctor<T>(-s, s));
+    }
+    for (int i = 0; i < channel; i++) {
+      T s = scale_data[i];
+      framework::Tensor one_channel_out = out->Slice(i, i + 1);
+      auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
+      out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
+    }
+  }
+};
+
+template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
+                                               float>;
+
 template <typename T>
 struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
@@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
         ctx->HasOutput("Out"),
         "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
     PADDLE_ENFORCE(
-        ctx->HasOutput("OutScales"),
-        "Output(Scales) of FakeChannelWiseQuantizeOp should not be null.");
+        ctx->HasOutput("OutScale"),
+        "Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
+    ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
@@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
     AddOutput("Out",
               "(Tensor) Output of quantized low level tensor, "
               "but also saved as float data type.");
-    AddOutput("OutScales", "(Tensor) Current channel wise scale");
+    AddOutput("OutScale", "(Tensor) Current channel wise scale");
     AddAttr<int>("bit_length", "(int, default 8)")
         .SetDefault(8)
         .AddCustomChecker([](const int& bit_length) {
```
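Taken together, the two new CPU functors implement per-channel fake quantization: FindChannelAbsMaxFunctor picks each channel's abs-max as its scale s, and ChannelClipAndFakeQuantFunctor clips that channel to [-s, s] and rounds onto the bin_cnt grid. Below is a standalone sketch of the combined pass, assuming channels are contiguous along the first dimension and bin_cnt = 2^(bit_length-1) - 1; plain C++ with no Paddle dependencies, and channel_fake_quant is an illustrative name.

```cpp
// Standalone sketch of the per-channel fake-quantization pass: find each
// channel's abs-max scale, clip to [-s, s], then round to the bin_cnt grid.
// Plain C++, no Paddle dependencies; names are illustrative.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void channel_fake_quant(const std::vector<float>& in, int channel, int bin_cnt,
                        std::vector<float>* scales, std::vector<float>* out) {
  const int channel_size = static_cast<int>(in.size()) / channel;
  for (int c = 0; c < channel; ++c) {
    // FindChannelAbsMaxFunctor: abs-max over this channel's slice.
    float s = 0.f;
    for (int i = 0; i < channel_size; ++i)
      s = std::max(s, std::fabs(in[c * channel_size + i]));
    (*scales)[c] = s;
    // ChannelClipAndFakeQuantFunctor: clip to [-s, s], then round to bins.
    for (int i = 0; i < channel_size; ++i) {
      const int idx = c * channel_size + i;
      const float clipped = std::min(std::max(in[idx], -s), s);
      (*out)[idx] = std::round(bin_cnt / s * clipped);
    }
  }
}

int main() {
  std::vector<float> in = {0.1f, -0.8f, 1.2f, -3.0f};  // 2 channels of 2
  std::vector<float> scales(2), out(4);
  channel_fake_quant(in, /*channel=*/2, /*bin_cnt=*/127, &scales, &out);
  for (int c = 0; c < 2; ++c) printf("scale[%d] = %f\n", c, scales[c]);
  for (float v : out) printf("%f\n", v);  // 16 -127 51 -127
  return 0;
}
```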
