Commit ec88b6c

Add channel-wise quantization in IR pass.
1 parent 81b4fad

File tree: 7 files changed, +290 -70 lines changed
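
Background for the change: channel-wise quantization gives each channel of a
weight tensor its own abs-max scale instead of one scale for the whole tensor,
which preserves more precision when channel magnitudes differ widely. A
standalone sketch of the scheme (plain C++, not Paddle code; all names here
are illustrative):

#include <cmath>
#include <vector>

// bit_length = 8 gives max_range = 2^7 - 1 = 127, matching the
// std::pow(2, quant_bits[0] - 1) - 1 expression in the kernels below.
int MaxRange(int bit_length) { return (1 << (bit_length - 1)) - 1; }

// Fake-quantize one channel against its own scale (the channel's abs-max);
// the result stays in float, as the op descriptions below note.
void QuantizeChannel(const std::vector<float>& in, float scale, int max_range,
                     std::vector<float>* out) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    (*out)[i] = std::round(in[i] / scale * max_range);
}

// Dequantization inverts the mapping: x ~= q * scale / max_range.
void DequantizeChannel(const std::vector<float>& in, float scale,
                       int max_range, std::vector<float>* out) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    (*out)[i] = in[i] * scale / max_range;
}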


paddle/fluid/operators/fake_dequantize_op.h

Lines changed: 33 additions & 13 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+#include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
@@ -54,26 +55,45 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     auto scales = ctx.MultiInput<framework::Tensor>("Scales");
     auto* out = ctx.Output<framework::Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
-                      "The number of first scale values must be the same with "
-                      "first dimension value of Input(X).");
-
     auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
     int max_range = std::pow(2, quant_bits[0] - 1) - 1;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     out->mutable_data<T>(dev_ctx.GetPlace());
 
     auto dequant = DequantizeFunctor<DeviceContext, T>();
-    for (int64_t i = 0; i < in->dims()[0]; i++) {
-      framework::Tensor one_channel_in = in->Slice(i, i + 1);
-      framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
-      dequant(dev_ctx, &one_channel_in, &one_channel_scale,
-              static_cast<T>(max_range), &one_channel_out);
-    }
-
-    if (scales.size() == 2) {
+    if (scales.size() == 1) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[0],
+          "The number of first scale values must be the same with "
+          "first dimension value of Input(X) when the `Scales` has only one "
+          "element.");
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_channel_in = in->Slice(i, i + 1);
+        framework::Tensor one_channel_out = out->Slice(i, i + 1);
+        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
+        dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                static_cast<T>(max_range), &one_channel_out);
+      }
+    } else if (scales.size() == 2) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[1],
+          "The number of first scale values must be the same with "
+          "second dimension value of Input(X) when the `Scales` has two "
+          "elements.");
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
+            framework::slice_ddim(in->dims(), 1, in->dims().size()));
+        framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
+            framework::slice_ddim(out->dims(), 1, out->dims().size()));
+        for (int64_t j = 0; j < in->dims()[1]; j++) {
+          framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
+          framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
+          framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
+          dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                  static_cast<T>(max_range), &one_channel_out);
+        }
+      }
       PADDLE_ENFORCE_EQ(
           scales[1]->numel(), 1,
           "The second scale tensor should only have one value at now.");

paddle/fluid/operators/fake_quantize_op.cc

Lines changed: 4 additions & 4 deletions
@@ -169,10 +169,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
         ctx->HasOutput("Out"),
         "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
     PADDLE_ENFORCE(
-        ctx->HasOutput("OutScales"),
-        "Output(Scales) of FakeChannelWiseQuantizeOp should not be null.");
+        ctx->HasOutput("OutScale"),
+        "Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
+    ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
@@ -192,7 +192,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
     AddOutput("Out",
               "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
-    AddOutput("OutScales", "(Tensor) Current channel wise scale");
+    AddOutput("OutScale", "(Tensor) Current channel wise scale");
     AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
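
This file only renames the plural "OutScales" output to "OutScale"; the
InferShape rule is unchanged and still sizes it to one scale per slice along
dim 0 of X. A hypothetical helper stating the same rule:

#include <cstdint>
#include <vector>

// For X of shape {out_channels, ...}, OutScale holds one value per output
// channel, matching ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}).
std::vector<int64_t> OutScaleDims(const std::vector<int64_t>& x_dims) {
  return {x_dims[0]};
}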

paddle/fluid/operators/fake_quantize_op.h

Lines changed: 4 additions & 4 deletions
@@ -78,8 +78,8 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
     auto* in = context.Input<framework::Tensor>("X");
 
     auto* out = context.Output<framework::Tensor>("Out");
-    auto* out_scales = context.Output<framework::Tensor>("OutScales");
-    T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace());
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
     out->mutable_data<T>(context.GetPlace());
 
     int bit_length = context.Attr<int>("bit_length");
@@ -91,13 +91,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
       framework::Tensor one_channel = in->Slice(i, i + 1);
       const T* one_channel_data = one_channel.data<T>();
       find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
-                   &out_scales_data[i]);
+                   &out_scale_data[i]);
     }
     auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
     for (int64_t i = 0; i < in->dims()[0]; i++) {
       framework::Tensor one_channel_in = in->Slice(i, i + 1);
       framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
+      framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1);
       clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
                  &one_channel_out);
     }
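
The kernel runs two passes: find_abs_max fills one scale per channel, then
ClipAndFakeQuantFunctor quantizes each channel against its scale. A
self-contained sketch of both passes (plain arrays, illustrative names; the
zero-channel guard is an addition, not something the kernel is shown doing):

#include <algorithm>
#include <cmath>
#include <vector>

void ChannelWiseAbsMaxQuant(const std::vector<float>& in, int channels,
                            int channel_size, int bit_length,
                            std::vector<float>* out,
                            std::vector<float>* out_scale) {
  const int bin_cnt = (1 << (bit_length - 1)) - 1;
  out->resize(in.size());
  out_scale->resize(channels);
  for (int i = 0; i < channels; ++i) {
    // Pass 1: find_abs_max over the channel slice.
    float s = 0.f;
    for (int k = 0; k < channel_size; ++k)
      s = std::max(s, std::fabs(in[i * channel_size + k]));
    (*out_scale)[i] = s;
    const float div = (s == 0.f) ? 1.f : s;  // guard, not in the kernel
    // Pass 2: clip to [-s, s], then scale into [-bin_cnt, bin_cnt].
    for (int k = 0; k < channel_size; ++k) {
      float v = std::max(-s, std::min(s, in[i * channel_size + k]));
      (*out)[i * channel_size + k] = std::round(v / div * bin_cnt);
    }
  }
}

Keeping the scale search separate from the quantization pass matches the
kernel's two loops and lets find_abs_max stay shared with the per-tensor op.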
