Skip to content

Commit 8965819

Browse files
committed
rewrite the cuda kernels of channel_wise_quant_op and channel_wise_dequant_op. test=develop
1 parent ec88b6c commit 8965819

File tree

8 files changed

+405
-116
lines changed

8 files changed

+405
-116
lines changed

paddle/fluid/operators/fake_dequantize_op.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
3333
}
3434
};
3535

36+
template <typename T>
37+
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
38+
void operator()(const platform::CPUDeviceContext& dev_ctx,
39+
const framework::Tensor* in, const framework::Tensor** scales,
40+
const int scale_num, T max_range, framework::Tensor* out) {
41+
if (scale_num == 1) {
42+
const int channel = in->dims()[0];
43+
const T* scale_factor = scales[0]->data<T>();
44+
for (int i = 0; i < channel; i++) {
45+
T s = scale_factor[i];
46+
framework::Tensor one_channel_in = in->Slice(i, i + 1);
47+
framework::Tensor one_channel_out = out->Slice(i, i + 1);
48+
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
49+
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
50+
auto& dev = *dev_ctx.eigen_device();
51+
out_e.device(dev) = (s / max_range) * in_e;
52+
}
53+
} else if (scale_num == 2) {
54+
int batch_size = in->dims()[0];
55+
int channel = in->dims()[1];
56+
const T* scale_one = scales[0]->data<T>();
57+
const T* scale_two = scales[1]->data<T>();
58+
for (int i = 0; i < batch_size; i++) {
59+
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
60+
framework::slice_ddim(in->dims(), 1, in->dims().size()));
61+
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
62+
framework::slice_ddim(out->dims(), 1, out->dims().size()));
63+
for (int j = 0; j < channel; j++) {
64+
T s = scale_one[j];
65+
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
66+
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
67+
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
68+
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
69+
auto& dev = *dev_ctx.eigen_device();
70+
out_e.device(dev) = (s * scale_two[0] / max_range) * in_e;
71+
}
72+
}
73+
}
74+
}
75+
};
76+
3677
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
3778
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
79+
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
80+
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
3881

3982
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
4083
public:

paddle/fluid/operators/fake_dequantize_op.cu

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
4444
}
4545
};
4646

47+
// Dequantizes per-channel quantized data using one scale per channel:
// out[i] = in[i] * scale[c] / max_range.
// Launch layout: gridDim.x == channel (one block per channel); the
// block's threads stride over that channel's num / channel elements.
template <typename T>
__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
                                   int num, int channel, T* out) {
  const int channel_size = num / channel;
  const T* in_ch = in + blockIdx.x * channel_size;
  T* out_ch = out + blockIdx.x * channel_size;
  const T s = scale[blockIdx.x];
  for (int idx = threadIdx.x; idx < channel_size; idx += blockDim.x) {
    out_ch[idx] = in_ch[idx] * s / max_range;
  }
}
58+
59+
// Dequantizes data quantized with two scales: a per-channel scale
// (scale_one[c]) and a single whole-tensor scale (scale_two[0]):
// out[i] = in[i] * scale_one[c] * scale_two[0] / max_range.
// Launch layout: gridDim.x == batch_size * channel, one block per
// (batch, channel) slice; threads stride over the slice's elements.
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
                                   const T* scale_two, T max_range, int num,
                                   int batch_size, int channel, T* out) {
  const int channel_size = num / (batch_size * channel);
  const int c = blockIdx.x % channel;  // channel index of this block's slice
  const T* in_ch = in + blockIdx.x * channel_size;
  T* out_ch = out + blockIdx.x * channel_size;
  for (int idx = threadIdx.x; idx < channel_size; idx += blockDim.x) {
    out_ch[idx] = in_ch[idx] * scale_one[c] * scale_two[0] / max_range;
  }
}
72+
73+
// Channel-wise dequantization on GPU. Dispatches to the one-scale or
// two-scale kernel depending on scale_num; any other scale_num launches
// nothing. Kernels are enqueued asynchronously on dev_ctx's stream.
template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor** scales,
                  const int scale_num, T max_range, framework::Tensor* out) {
    const T* in_data = in->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
    const int num = in->numel();
    const int threads = 1024;  // max threads per block; grid = slice count
    if (scale_num == 1) {
      // One scale per channel; input layout is (channel, ...).
      const int channel = in->dims()[0];
      const T* scale_factor = scales[0]->data<T>();
      DequantizeOneScale<T><<<channel, threads, 0, dev_ctx.stream()>>>(
          in_data, scale_factor, max_range, num, channel, out_data);
    } else if (scale_num == 2) {
      // Per-channel scale plus one whole-tensor scale; layout is
      // (batch, channel, ...).
      const int batch_size = in->dims()[0];
      const int channel = in->dims()[1];
      const T* scale_one = scales[0]->data<T>();
      const T* scale_two = scales[1]->data<T>();
      DequantizeTwoScale<T><<<batch_size * channel, threads, 0,
                              dev_ctx.stream()>>>(
          in_data, scale_one, scale_two, max_range, num, batch_size, channel,
          out_data);
    }
  }
};
102+
47103
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
48104
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
105+
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
106+
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;
49107

50108
} // namespace operators
51109
} // namespace paddle

paddle/fluid/operators/fake_dequantize_op.h

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ struct DequantizeFunctor {
2929
framework::Tensor* out);
3030
};
3131

32+
// Dequantizes `in` channel-wise into `out` using `scale_num` scale
// tensors from `scales` and the combined quantization range `max_range`.
// The visible CPU/CUDA specializations handle scale_num == 1 (one scale
// per channel) and scale_num == 2 (per-channel scale plus one
// whole-tensor scale); specialized per DeviceContext.
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
                  const framework::Tensor** scales, const int scale_num,
                  T max_range, framework::Tensor* out);
};
38+
3239
template <typename DeviceContext, typename T>
3340
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
3441
public:
@@ -56,50 +63,32 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
5663
auto* out = ctx.Output<framework::Tensor>("Out");
5764

5865
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
59-
int max_range = std::pow(2, quant_bits[0] - 1) - 1;
66+
int max_range = 1;
6067

6168
auto& dev_ctx = ctx.template device_context<DeviceContext>();
6269
out->mutable_data<T>(dev_ctx.GetPlace());
63-
64-
auto dequant = DequantizeFunctor<DeviceContext, T>();
65-
if (scales.size() == 1) {
70+
int scale_num = scales.size();
71+
if (scale_num == 1) {
6672
PADDLE_ENFORCE_EQ(
6773
scales[0]->numel(), in->dims()[0],
6874
"The number of first scale values must be the same with "
6975
"first dimension value of Input(X) when the `Scales` has only one "
7076
"element.");
71-
for (int64_t i = 0; i < in->dims()[0]; i++) {
72-
framework::Tensor one_channel_in = in->Slice(i, i + 1);
73-
framework::Tensor one_channel_out = out->Slice(i, i + 1);
74-
framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
75-
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
76-
static_cast<T>(max_range), &one_channel_out);
77-
}
78-
} else if (scales.size() == 2) {
77+
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
78+
} else if (scale_num == 2) {
7979
PADDLE_ENFORCE_EQ(
8080
scales[0]->numel(), in->dims()[1],
8181
"The number of first scale values must be the same with "
8282
"second dimension value of Input(X) when the `Scales` has two "
8383
"elements.");
84-
for (int64_t i = 0; i < in->dims()[0]; i++) {
85-
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
86-
framework::slice_ddim(in->dims(), 1, in->dims().size()));
87-
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
88-
framework::slice_ddim(out->dims(), 1, out->dims().size()));
89-
for (int64_t j = 0; j < in->dims()[1]; j++) {
90-
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
91-
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
92-
framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
93-
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
94-
static_cast<T>(max_range), &one_channel_out);
95-
}
96-
}
9784
PADDLE_ENFORCE_EQ(
9885
scales[1]->numel(), 1,
9986
"The second scale tensor should only have one value at now.");
100-
max_range = std::pow(2, quant_bits[1] - 1) - 1;
101-
dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out);
87+
max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
88+
(std::pow(2, quant_bits[1] - 1) - 1);
10289
}
90+
ChannelDequantizeFunctor<DeviceContext, T>()(
91+
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range), out);
10392
}
10493
};
10594

paddle/fluid/operators/fake_quantize_op.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
3737

3838
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
3939

40+
// Finds, for each of `channel` equal-sized contiguous sections of `in`
// (section length num / channel), the absolute value of the element
// selected by std::max_element under Compare<T>, writing it to out[i].
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, const int channel, T* out) {
    // Assumes num is an exact multiple of channel — TODO confirm caller
    // guarantees this; a remainder would leave trailing elements unscanned.
    const int channel_size = num / channel;
    for (int i = 0; i < channel; i++) {
      auto* start = in + i * channel_size;
      auto* end = in + (i + 1) * channel_size;
      // NOTE(review): correctness depends on Compare<T> (defined elsewhere)
      // ordering elements by absolute value; if it is a plain operator<,
      // a large-magnitude negative element would be missed — verify.
      out[i] = std::abs(*(std::max_element(start, end, Compare<T>())));
    }
  }
};
52+
53+
template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
54+
4055
template <typename T>
4156
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
4257
void operator()(const platform::CPUDeviceContext& ctx,
@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
5368

5469
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
5570

71+
template <typename T>
72+
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
73+
void operator()(const platform::CPUDeviceContext& ctx,
74+
const framework::Tensor& in, const framework::Tensor& scale,
75+
const int bin_cnt, const int channel,
76+
framework::Tensor* out) {
77+
auto* scale_data = scale.data<T>();
78+
auto* in_data = in.data<T>();
79+
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
80+
const int channel_size = in.numel() / channel;
81+
platform::Transform<platform::CPUDeviceContext> trans;
82+
for (int i = 0; i < channel; i++) {
83+
T s = scale_data[i];
84+
auto* start = in_data + i * channel_size;
85+
auto* end = in_data + (i + 1) * channel_size;
86+
trans(ctx, start, end, out_data + i * channel_size,
87+
ClipFunctor<T>(-s, s));
88+
}
89+
for (int i = 0; i < channel; i++) {
90+
T s = scale_data[i];
91+
framework::Tensor one_channel_out = out->Slice(i, i + 1);
92+
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
93+
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
94+
}
95+
}
96+
};
97+
98+
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
99+
float>;
100+
56101
template <typename T>
57102
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
58103
void operator()(const platform::CPUDeviceContext& ctx,

0 commit comments

Comments
 (0)