Skip to content

Commit 8cda7b3

Browse files
committed
Merge remote-tracking branch 'ups/develop' into fea/jit/act
test=develop
2 parents e2d6edd + 64f7516 commit 8cda7b3

29 files changed

+774
-75
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates',
128128
paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
129129
paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
130130
paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
131+
paddle.fluid.layers.selu ArgSpec(args=['x', 'scale', 'alpha', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
131132
paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
132133
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
133134
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pass_library(seq_concat_fc_fuse_pass inference)
4141
pass_library(multi_batch_merge_pass base)
4242
pass_library(conv_bn_fuse_pass inference)
4343
pass_library(seqconv_eltadd_relu_fuse_pass inference)
44+
pass_library(is_test_pass base)
4445
if(WITH_MKLDNN)
4546
pass_library(mkldnn_placement_pass base)
4647
pass_library(depthwise_conv_mkldnn_pass base)
@@ -62,6 +63,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
6263
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
6364
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
6465
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
66+
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
6567
if (WITH_MKLDNN)
6668
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
6769
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/is_test_pass.h"
16+
#include <string>
17+
#include <utility>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
24+
std::unique_ptr<ir::Graph> graph) const {
25+
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
26+
"for activations and pooling.";
27+
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
28+
"softshrink", "exp", "brelu",
29+
"pow", "leaky_relu", "stanh",
30+
"relu", "tanh", "tanh_shrink",
31+
"sqrt", "abs", "ceil",
32+
"elu", "floor", "cos",
33+
"sin", "round", "reciprocal",
34+
"hard_shrink", "hard_sigmoid", "relu6",
35+
"soft_relu", "swish", "thresholded_relu",
36+
"log", "square", "softplus",
37+
"softsign"};
38+
for (const Node* n : graph->Nodes()) {
39+
if (n->IsOp()) {
40+
auto* op = n->Op();
41+
if (op->HasAttr("is_test")) {
42+
op->SetAttr("is_test", true);
43+
} else if (std::find(begin(op_list), end(op_list), op->Type()) !=
44+
end(op_list)) {
45+
op->MutableAttrMap()->insert(
46+
std::pair<std::string, Attribute>("is_test", true));
47+
}
48+
}
49+
}
50+
return graph;
51+
}
52+
53+
} // namespace ir
54+
} // namespace framework
55+
} // namespace paddle
56+
57+
REGISTER_PASS(is_test_pass, paddle::framework::ir::IsTestPass);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/ir/pass.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
class IsTestPass : public Pass {
24+
protected:
25+
std::unique_ptr<ir::Graph> ApplyImpl(
26+
std::unique_ptr<ir::Graph> graph) const override;
27+
};
28+
29+
} // namespace ir
30+
} // namespace framework
31+
} // namespace paddle
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/is_test_pass.h"
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
enum class ISTEST_STATE { FALSE, TRUE, UNSET };
24+
25+
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
26+
const std::vector<std::string>& inputs,
27+
const std::vector<std::string>& outputs, bool use_mkldnn = false,
28+
ISTEST_STATE is_test = ISTEST_STATE::UNSET) {
29+
auto* op = prog->MutableBlock(0)->AppendOp();
30+
op->SetType(type);
31+
op->SetAttr("name", name);
32+
op->SetInput("X", inputs);
33+
op->SetOutput("Out", outputs);
34+
op->SetAttr("use_mkldnn", use_mkldnn);
35+
if (is_test == ISTEST_STATE::UNSET)
36+
op->MutableAttrMap()->erase("is_test");
37+
else if (is_test == ISTEST_STATE::FALSE)
38+
op->SetAttr("is_test", false);
39+
else
40+
op->SetAttr("is_test", true);
41+
}
42+
43+
// a->pool2d->b
44+
// b->relu->c
45+
// c,weights1)->conv2d->d
46+
//
47+
// d->pool2d->e
48+
// e->hard_sigmoid->f
49+
// (f,weights2)->conv2d->g
50+
//
51+
// g->pool2d->h
52+
// h->tanh->i
53+
// (i,weights3)->conv2d->j
54+
ProgramDesc BuildProgramDesc() {
55+
ProgramDesc prog;
56+
for (auto& v :
57+
std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h", "i",
58+
"j", "weights1", "weights2", "weights3"})) {
59+
auto* var = prog.MutableBlock(0)->Var(v);
60+
var->SetType(proto::VarType::SELECTED_ROWS);
61+
if (v == "weights1" || v == "weights2" || v == "weights3") {
62+
var->SetPersistable(true);
63+
}
64+
}
65+
66+
SetOp(&prog, "pool2d", "pooling1", std::vector<std::string>({"a"}),
67+
std::vector<std::string>({"b"}), true, ISTEST_STATE::TRUE);
68+
SetOp(&prog, "relu", "activation1", std::vector<std::string>({"b"}),
69+
std::vector<std::string>({"c"}), true, ISTEST_STATE::TRUE);
70+
SetOp(&prog, "conv2d", "conv1", std::vector<std::string>({"c", "weights1"}),
71+
std::vector<std::string>({"d"}), true, ISTEST_STATE::TRUE);
72+
73+
SetOp(&prog, "pool2d", "pooling2", std::vector<std::string>({"d"}),
74+
std::vector<std::string>({"e"}), false, ISTEST_STATE::FALSE);
75+
SetOp(&prog, "hard_sigmoid", "activation2", std::vector<std::string>({"e"}),
76+
std::vector<std::string>({"f"}), false, ISTEST_STATE::FALSE);
77+
SetOp(&prog, "conv2d", "conv2", std::vector<std::string>({"f", "weights2"}),
78+
std::vector<std::string>({"g"}), false, ISTEST_STATE::FALSE);
79+
80+
SetOp(&prog, "pool2d", "pooling3", std::vector<std::string>({"g"}),
81+
std::vector<std::string>({"h"}), false, ISTEST_STATE::UNSET);
82+
SetOp(&prog, "tanh", "activation3", std::vector<std::string>({"h"}),
83+
std::vector<std::string>({"i"}), true, ISTEST_STATE::UNSET);
84+
SetOp(&prog, "conv2d", "conv3", std::vector<std::string>({"i", "weights3"}),
85+
std::vector<std::string>({"j"}), false, ISTEST_STATE::UNSET);
86+
87+
return prog;
88+
}
89+
90+
TEST(IsTestPass, basic) {
91+
auto prog = BuildProgramDesc();
92+
93+
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
94+
95+
auto pass = PassRegistry::Instance().Get("is_test_pass");
96+
97+
graph = pass->Apply(std::move(graph));
98+
99+
for (auto* node : graph->Nodes()) {
100+
if (node->IsOp()) {
101+
auto* op = node->Op();
102+
auto op_name = boost::get<std::string>(op->GetAttr("name"));
103+
if (op_name == "conv3") {
104+
ASSERT_FALSE(op->HasAttr("is_test"));
105+
} else {
106+
ASSERT_TRUE(op->HasAttr("is_test"));
107+
EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
108+
}
109+
}
110+
}
111+
}
112+
113+
} // namespace ir
114+
} // namespace framework
115+
} // namespace paddle
116+
117+
USE_PASS(is_test_pass);

paddle/fluid/inference/api/paddle_pass_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class CpuPassStrategy : public PassStrategy {
8686
"fc_fuse_pass", //
8787
"conv_bn_fuse_pass", //
8888
"conv_eltwiseadd_bn_fuse_pass", //
89+
"is_test_pass", //
8990
});
9091
}
9192

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_te
7878
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
7979
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
8080

81+
# mobilenet with depthwise_conv op
82+
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet
83+
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
84+
8185
# anakin
8286
if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
8387
# anakin rnn1

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ class MKLDNNActivationGradKernel
7171
diff_y->format() != memory::format::format_undef,
7272
"Wrong layout/format set for Input OutGrad tensor");
7373

74+
PADDLE_ENFORCE(
75+
!ctx.Attr<bool>("is_test"),
76+
"is_test attribute should be set to False in training phase.");
77+
7478
Functor functor;
7579

7680
auto attrs = functor.GetAttrs();
@@ -115,11 +119,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
115119
const std::string key_fwd = key_with_layout + "@eltwise_fwd";
116120
const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
117121

122+
bool is_test = ctx.Attr<bool>("is_test");
123+
118124
// save input data and layout to be referred in backward path
119125
auto p_src_data = std::make_shared<const T *>(x_data);
120-
dev_ctx.SetBlob(key_src_data, p_src_data);
121126
auto p_src_layout = std::make_shared<memory::format>(src_format);
122-
dev_ctx.SetBlob(key_src_layout, p_src_layout);
127+
if (!is_test) {
128+
dev_ctx.SetBlob(key_src_data, p_src_data);
129+
dev_ctx.SetBlob(key_src_layout, p_src_layout);
130+
}
123131

124132
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
125133
dev_ctx.GetBlob(key_fwd));
@@ -136,14 +144,17 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
136144
dev_ctx.SetBlob(key_src_mem, src_memory);
137145

138146
// create primitive descriptor for activation forward and save it
147+
auto mkldnn_forward_prop_kind = is_test
148+
? mkldnn::prop_kind::forward_inference
149+
: mkldnn::prop_kind::forward_training;
139150
auto forward_desc = mkldnn::eltwise_forward::desc(
140-
mkldnn::prop_kind::forward_training, algorithm,
151+
mkldnn_forward_prop_kind, algorithm,
141152
src_memory->get_primitive_desc().desc(), alpha, beta);
142153
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
143154
forward_desc, mkldnn_engine);
144155

145156
// save prim desc into global device context to be referred in backward path
146-
dev_ctx.SetBlob(key_fwd_pd, forward_pd);
157+
if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);
147158

148159
// create mkldnn memory for output y
149160
dst_memory =

paddle/fluid/operators/activation_op.cc

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,23 @@ namespace operators {
2222

2323
using paddle::framework::Tensor;
2424

25-
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
26-
class OP_NAME##OpMaker \
27-
: public ::paddle::framework::OpProtoAndCheckerMaker { \
28-
public: \
29-
void Make() override { \
30-
AddInput("X", "Input of " #OP_NAME " operator"); \
31-
AddOutput("Out", "Output of " #OP_NAME " operator"); \
32-
AddAttr<bool>("use_mkldnn", \
33-
"(bool, default false) Only used in mkldnn kernel") \
34-
.SetDefault(false); \
35-
AddComment(#OP_COMMENT); \
36-
} \
25+
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
26+
class OP_NAME##OpMaker \
27+
: public ::paddle::framework::OpProtoAndCheckerMaker { \
28+
public: \
29+
void Make() override { \
30+
AddInput("X", "Input of " #OP_NAME " operator"); \
31+
AddOutput("Out", "Output of " #OP_NAME " operator"); \
32+
AddAttr<bool>("use_mkldnn", \
33+
"(bool, default false) Only used in mkldnn kernel") \
34+
.SetDefault(false); \
35+
AddAttr<bool>( \
36+
"is_test", \
37+
"(bool, default false) Set to true for inference only, false " \
38+
"for training. Some layers may run faster when this is true.") \
39+
.SetDefault(false); \
40+
AddComment(#OP_COMMENT); \
41+
} \
3742
}
3843

3944
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
@@ -269,7 +274,7 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
269274
:strong:`Softshrink Activation Operator`
270275
271276
.. math::
272-
out = \begin{cases}
277+
out = \begin{cases}
273278
x - \lambda, \text{if } x > \lambda \\
274279
x + \lambda, \text{if } x < -\lambda \\
275280
0, \text{otherwise}
@@ -435,7 +440,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
435440
AddComment(R"DOC(
436441
HardSigmoid Activation Operator.
437442
438-
Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
443+
Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
439444
which is much faster than sigmoid.
440445
441446
$out = \max(0, \min(1, slope * x + shift))$

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
113113
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
114114
public:
115115
void Make() override {
116-
AddAttr<bool>("is_test", "").SetDefault(false);
116+
AddAttr<bool>("is_test",
117+
"(bool, default false) Set to true for inference only, false "
118+
"for training. Some layers may run faster when this is true.")
119+
.SetDefault(false);
117120
AddAttr<float>("momentum", "").SetDefault(0.9);
118121
AddAttr<float>("epsilon", "")
119122
.SetDefault(1e-5)

0 commit comments

Comments
 (0)