Commit e9bec93
[slim] Add quantization strategy and distillation strategy. (#16408)
* Add fsp operator. 1. Add unit test. 2. Add Python API. 3. Add layer test.
* Add quantization strategy. 1. Add API. 2. Add unit test.
* Add distillation strategy.
* Add unit-test config file for quantization.
* Fix copyright. test=develop
* Fix setup.py.
* Fix documentation of layers.py. test=develop
* Fix unit tests under Python 3. test=develop
* Fix documents. test=develop
* 1. Refine fsp op with batched GEMM. 2. Remove unused import. test=develop
* Fix test_dist_se_resnext. 1. Disable distillation test. 2. Reset framework.py. test=develop
* Enable distillation unit test after fixing Block._clone_variable. test=develop
* Fix CDN issue. test=develop
1 parent de3b70a commit e9bec93

File tree

24 files changed: +1214 −29 lines

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label'
 paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7'))
 paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
 paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
+paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
 paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
 paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
 paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
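
The new ArgSpec entry records the Python surface of the operator: paddle.fluid.layers.fsp_matrix(x, y). A minimal usage sketch under assumed shapes (variable names and sizes are illustrative, not from the commit; fluid.layers.data and its defaults are taken from the spec lines above):

import paddle.fluid as fluid

# fluid.layers.data prepends the batch dimension (append_batch_size=True by
# default), so the shapes below are [channel, height, width] per sample.
teacher_feat = fluid.layers.data(
    name='teacher_feat', shape=[64, 32, 32], dtype='float32')
student_feat = fluid.layers.data(
    name='student_feat', shape=[16, 32, 32], dtype='float32')

# Batched FSP matrix; per fsp_op.cc below, the output shape is
# [batch_size, x_channel, y_channel], i.e. [-1, 64, 16] here.
fsp = fluid.layers.fsp_matrix(teacher_feat, student_feat)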

paddle/fluid/operators/fsp_op.cc

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fsp_op.h"

namespace paddle {
namespace operators {

class FSPOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FSPOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    PADDLE_ENFORCE(
        x_dims.size() == 4,
        "The Input(X) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        y_dims.size() == 4,
        "The Input(Y) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        (x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]),
        "The Input(X) and Input(Y) should have the same height and width.");

    ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]});
    ctx->ShareLoD("X", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context(), layout_, library_);
  }
};

class FSPOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input of FSP op with shape [batch_size, x_channel, "
             "height, width]");
    AddInput("Y",
             "(Tensor) The input of FSP op with shape "
             "[batch_size, y_channel, height, width]. "
             "The y_channel can be different from the x_channel of Input(X) "
             "while the other dimensions must be the same as Input(X)'s.");
    AddOutput(
        "Out",
        "(Tensor) The output of FSP op with shape "
        "[batch_size, x_channel, y_channel]. The x_channel is the channel "
        "of Input(X) and the y_channel is the channel of Input(Y).");
    AddComment(R"DOC(
This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps.
Given feature map x with shape [x_channel, h, w] and feature map y with shape
[y_channel, h, w], we can get the fsp matrix of x and y in two steps:

step 1: reshape x into a matrix with shape [x_channel, h * w], and reshape and
transpose y into a matrix with shape [h * w, y_channel]
step 2: multiply x and y to get the fsp matrix with shape [x_channel, y_channel]

The output is a batch of fsp matrices.
)DOC");
  }
};

class FSPOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
REGISTER_OP_CPU_KERNEL(
    fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    fsp_grad, ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, double>);
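
As a cross-check of the DOC comment in FSPOpMaker, here is a reference NumPy sketch (not part of the commit; the helper name is illustrative) of the computation the kernels perform with one batched GEMM:

import numpy as np

def fsp_matrix_ref(x, y):
    """x: [n, x_c, h, w]; y: [n, y_c, h, w]; returns [n, x_c, y_c]."""
    n, x_c, h, w = x.shape
    y_c = y.shape[1]
    x_r = x.reshape(n, x_c, h * w)                     # step 1: flatten spatial dims
    y_r = y.reshape(n, y_c, h * w).transpose(0, 2, 1)  # step 1: reshape + transpose
    return np.matmul(x_r, y_r) / (h * w)               # step 2: scaled batched matmul

x = np.random.rand(2, 64, 8, 8).astype('float32')
y = np.random.rand(2, 16, 8, 8).astype('float32')
assert fsp_matrix_ref(x, y).shape == (2, 64, 16)

The 1 / (h * w) factor corresponds to the alpha passed to blas.MatMul in fsp_op.h below.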

paddle/fluid/operators/fsp_op.cu

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fsp_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fsp, ops::FSPOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(fsp_grad,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/fsp_op.h

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class FSPOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* output = context.Output<Tensor>("Out");
    output->mutable_data<T>(context.GetPlace());
    auto x_dims = x->dims();
    auto y_dims = y->dims();

    auto batch_size = x_dims[0];
    auto x_channel = x_dims[1];
    auto y_channel = y_dims[1];
    auto height = x_dims[2];
    auto width = x_dims[3];

    auto blas = math::GetBlas<DeviceContext, T>(context);

    math::MatDescriptor x_mat_desc;
    x_mat_desc.height_ = x_channel;
    x_mat_desc.width_ = height * width;
    x_mat_desc.batch_size_ = batch_size;
    x_mat_desc.stride_ = x_channel * height * width;

    math::MatDescriptor y_mat_desc;
    y_mat_desc.height_ = height * width;
    y_mat_desc.width_ = y_channel;
    y_mat_desc.batch_size_ = batch_size;
    y_mat_desc.stride_ = y_channel * height * width;
    y_mat_desc.trans_ = true;

    blas.MatMul(*x, x_mat_desc, *y, y_mat_desc,
                static_cast<T>(1.0 / (height * width)), output,
                static_cast<T>(0.0));
  }
};

template <typename DeviceContext, typename T>
class FSPGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* d_y = context.Output<Tensor>(framework::GradVarName("Y"));
    if (d_x == nullptr && d_y == nullptr) {
      return;
    }
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto d_out_dims = d_out->dims();
    auto batch_size = d_out_dims[0];
    auto x_channel = d_out_dims[1];
    auto y_channel = d_out_dims[2];
    int64_t h = 0;
    int64_t w = 0;

    auto blas = math::GetBlas<DeviceContext, T>(context);
    math::SetConstant<DeviceContext, T> set_zero;
    if (d_x != nullptr) {
      d_x->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_x,
               static_cast<T>(0));
      auto* y = context.Input<Tensor>("Y");
      auto y_dims = y->dims();
      h = y_dims[2];
      w = y_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = x_channel;
      d_out_mat_desc.width_ = y_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;

      math::MatDescriptor y_mat_desc;
      y_mat_desc.height_ = y_channel;
      y_mat_desc.width_ = h * w;
      y_mat_desc.batch_size_ = batch_size;
      y_mat_desc.stride_ = y_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
    }

    if (d_y != nullptr) {
      d_y->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_y,
               static_cast<T>(0));
      auto* x = context.Input<Tensor>("X");
      auto x_dims = x->dims();
      h = x_dims[2];
      w = x_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = y_channel;
      d_out_mat_desc.width_ = x_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;
      d_out_mat_desc.trans_ = true;

      math::MatDescriptor x_mat_desc;
      x_mat_desc.height_ = x_channel;
      x_mat_desc.width_ = h * w;
      x_mat_desc.batch_size_ = batch_size;
      x_mat_desc.stride_ = x_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
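
The two MatMul calls in FSPGradOpKernel implement d_x = d_out @ y_r / (h*w) and d_y = d_out^T @ x_r / (h*w) on the spatially flattened feature maps. A hedged NumPy sketch of the same math (illustrative only, not from the commit):

import numpy as np

def fsp_grad_ref(x, y, d_out):
    """x: [n, x_c, h, w]; y: [n, y_c, h, w]; d_out: [n, x_c, y_c]."""
    n, x_c, h, w = x.shape
    x_r = x.reshape(n, x_c, h * w)
    y_r = y.reshape(n, y.shape[1], h * w)
    # [n, x_c, y_c] @ [n, y_c, h*w] -> [n, x_c, h*w]
    d_x = np.matmul(d_out, y_r) / (h * w)
    # [n, y_c, x_c] @ [n, x_c, h*w] -> [n, y_c, h*w]
    d_y = np.matmul(d_out.transpose(0, 2, 1), x_r) / (h * w)
    return d_x.reshape(x.shape), d_y.reshape(y.shape)

x = np.random.rand(2, 64, 8, 8).astype('float32')
y = np.random.rand(2, 16, 8, 8).astype('float32')
d_out = np.ones((2, 64, 16), dtype='float32')
d_x, d_y = fsp_grad_ref(x, y, d_out)
assert d_x.shape == x.shape and d_y.shape == y.shape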

python/paddle/fluid/contrib/slim/core/compressor.py

Lines changed: 1 addition & 1 deletion
@@ -271,7 +271,7 @@ def __init__(self,
         self.eval_reader = eval_reader
         self.teacher_graphs = []
         for teacher in teacher_programs:
-            self.teacher_graphs.append(ImitationGraph(teacher, scope=scope))
+            self.teacher_graphs.append(GraphWrapper(teacher))

         self.checkpoint = None
         self.checkpoint_path = checkpoint_path

python/paddle/fluid/contrib/slim/core/config.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from ..prune import *
 from ..quantization import *
 from .strategy import *
+from ..distillation import *

 __all__ = ['ConfigFactory']
 """This factory is used to create instances by loading and parsing configure file with yaml format.

python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py renamed to python/paddle/fluid/contrib/slim/distillation/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,3 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from . import distiller
+from .distiller import *
+from . import distillation_strategy
+from .distillation_strategy import *
+
+__all__ = distiller.__all__
+__all__ += distillation_strategy.__all__
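
The renamed __init__.py uses the standard re-export pattern: each submodule declares its public names in __all__, and the package aggregates them so `from ..distillation import *` (as added in config.py above) pulls in exactly those names. A self-contained sketch of the idiom (module and class names here are hypothetical, not from the commit):

# distiller.py (hypothetical submodule)
__all__ = ['ExampleDistiller']

class ExampleDistiller(object):
    """Placeholder public class exported by the submodule."""
    pass

# __init__.py (package): re-export the submodule's public names and make
# the package's __all__ the union of its submodules' __all__ lists.
from . import distiller
from .distiller import *

__all__ = distiller.__all__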
