
Commit 5a38930

Merge of 2 parents: ac2eba4 + 9517a45
Commit message: test=develop

30 files changed, +553 −172 lines

README.md

Lines changed: 5 additions & 5 deletions
@@ -2,8 +2,8 @@
 
 
 [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
-[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html)
-[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html)
+[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.0/getstarted/index_en.html)
+[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html)
 [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
 [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)

@@ -19,17 +19,17 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
 Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
 
 
-### Latest PaddlePaddle Release: [Fluid 1.0.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0)
+### Latest PaddlePaddle Release: [Fluid 1.0.1](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0)
 ### Install Latest Stable Release:
 ```
 # Linux CPU
 pip install paddlepaddle
 # Linux GPU cuda9cudnn7
 pip install paddlepaddle-gpu
 # Linux GPU cuda8cudnn7
-pip install paddlepaddle-gpu==0.15.0.post87
+pip install paddlepaddle-gpu==1.0.1.post87
 # Linux GPU cuda8cudnn5
-pip install paddlepaddle-gpu==0.15.0.post85
+pip install paddlepaddle-gpu==1.0.1.post85
 
 # For installation on other platform, refer to http://paddlepaddle.org/
 ```
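Not part of the diff: a quick post-install smoke test, a minimal sketch assuming the Fluid 1.0.x Python API (`paddle.__version__`, `fluid.Executor`, and `fluid.layers.fill_constant` are the assumed entry points):

```
# smoke-test the freshly installed wheel on CPU
import paddle
import paddle.fluid as fluid

print(paddle.__version__)  # expected: 1.0.1 for the pinned GPU wheels above

# build and run a trivial fluid program
x = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)
exe = fluid.Executor(fluid.CPUPlace())
out, = exe.run(fluid.default_main_program(), fetch_list=[x])
print(out)  # [1.]
```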

cmake/generic.cmake

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,8 @@ function(cc_test TARGET_NAME)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
+      # No unit test should exceed 10 minutes.
+      set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
   endif()
 endfunction(cc_test)
 
@@ -629,6 +631,8 @@ function(py_test TARGET_NAME)
         PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
         ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
         WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    # No unit test should exceed 10 minutes.
+    set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
   endif()
 endfunction()
 
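For context, a minimal standalone sketch (not from this commit) of how the new TIMEOUT property behaves: any test registered through cc_test/py_test that runs past 600 seconds is killed by CTest and reported as "***Timeout". The target name here is hypothetical.

```
# hypothetical standalone example
add_test(NAME my_unit_test COMMAND my_unit_test_binary)
# cap the test at 10 minutes, as cc_test/py_test now do
set_tests_properties(my_unit_test PROPERTIES TIMEOUT 600)
```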

paddle/fluid/API.spec

Lines changed: 5 additions & 5 deletions
@@ -61,12 +61,12 @@ paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None
 paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
 paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None))
+paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
 paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
 paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
 paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False))
-paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
+paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
+paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
 paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
 paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
 paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
@@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
 paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
-paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'], varargs=None, keywords=None, defaults=(None, None, None, None))
-paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None))
+paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None))
+paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
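The changed ArgSpecs are visible at the Python layer. A short sketch of assumed usage, matching the new signatures above (variable names are illustrative):

```
import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[32], dtype='float32')
# softmax/sequence_softmax no longer accept param_attr/bias_attr;
# use_cudnn keeps its default, and these layers now take an optional name.
probs = fluid.layers.softmax(data, use_cudnn=True, name='probs')
```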

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
@@ -10,7 +10,7 @@ function(pass_library TARGET DEST)
   set(oneValueArgs "")
   set(multiValueArgs SRCS DEPS)
   cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-  cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS})
+  cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
   # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
   if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
     message(STATUS "add pass ${TARGET} ${DEST}")
@@ -25,20 +25,22 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
 cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
 
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(fc_fuse_pass inference)
-if (WITH_MKLDNN)
-  pass_library(conv_relu_mkldnn_fuse_pass inference)
-endif ()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
 pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
 pass_library(conv_bn_fuse_pass inference)
+if(WITH_MKLDNN)
+  pass_library(mkldnn_placement_pass base)
+  pass_library(conv_relu_mkldnn_fuse_pass inference)
+endif()
 
 cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
 

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
   std::unordered_set<std::string> specified_vars({"data_lod_attention",
                                                   "cell_init", "hidden_init",
                                                   "data", "week", "minute"});
-  int count = 0;
+  size_t count = 0;
   for (auto* node : graph->Nodes()) {
     if (node->IsVar() && specified_vars.count(node->Name())) {
       ++count;
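A plausible motivation for the int → size_t change (an assumption, not stated in the commit): the counter is later compared against an unsigned quantity such as specified_vars.size(), and matching the types avoids a sign-compare warning and implicit conversion. A standalone sketch:

```
// sketch.cc: hypothetical example, not Paddle code
#include <string>
#include <unordered_set>

int main() {
  std::unordered_set<std::string> specified_vars{"cell_init", "hidden_init"};
  size_t count = 0;  // size_t so the comparison below is unsigned == unsigned
  for (const auto& name : specified_vars) {
    if (!name.empty()) ++count;
  }
  return count == specified_vars.size() ? 0 : 1;
}
```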

paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

Lines changed: 63 additions & 23 deletions
@@ -126,12 +126,21 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
     // conv, batch_norm,
     // conv_weight, conv_out,
     // bn_scale, bn_bias, bn_mean, bn_variance,
-    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
+    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
+    // bn_saved_variance
     GET_CONV_BN_NODES(conv_bn_pattern);
 
+    // check if fuse can be done and if MKL-DNN should be used
+    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
+    if (fuse_option == DO_NOT_FUSE) {
+      VLOG(3) << "do not perform conv+bn fuse";
+      return;
+    }
+
     // Create eltwise_y (conv bias) variable
     VarDesc eltwise_y_in_desc(
         patterns::PDNodeName(name_scope_, "eltwise_y_in"));
+    eltwise_y_in_desc.SetPersistable(true);
     auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
     auto* eltwise_y_in_tensor =
         scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
@@ -151,27 +160,59 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
                                 *bn_mean, *bn_variance, eltwise_y_in_tensor,
                                 epsilon);
 
-    // Create an elementwise add node
-    OpDesc desc;
-    desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
-    desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
-    desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
-    desc.SetType("elementwise_add");
-    desc.SetAttr("axis", 1);
-    bool a = boost::get<bool>(conv->Op()->GetAttr("use_mkldnn"));
-    desc.SetAttr("use_mkldnn", a);
-    auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
-
-    GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance,
-                                       batch_norm, bn_mean_out, bn_variance_out,
-                                       bn_saved_mean, bn_saved_variance});
-
-    PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(conv_out, eltwise_op);
-    IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
-    IR_NODE_LINK_TO(eltwise_op, bn_out);
-
-    found_conv_bn_count++;
+    // with MKL-DNN fuse conv+bn into conv with bias
+    // without MKL-DNN fuse conv+bn into conv+elementwise_add
+    if (fuse_option == FUSE_MKLDNN) {
+      auto input_names = conv->Op()->InputNames();
+      bool has_bias = std::find(input_names.begin(), input_names.end(),
+                                "Bias") != input_names.end();
+      if (has_bias && conv->Op()->Input("Bias").size() > 0) {
+        // reuse existing conv bias node
+        auto conv_bias_names = conv->Op()->Input("Bias");
+        PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
+        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
+        auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
+        PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
+                          eltwise_y_in_tensor->dims());
+
+        auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
+        eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
+      } else {
+        // add new conv_bias node
+        conv->Op()->SetInput(
+            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
+        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
+      }
+      conv->Op()->SetOutput("Output",
+                            std::vector<std::string>({bn_out->Name()}));
+
+      GraphSafeRemoveNodes(
+          graph.get(),
+          {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
+           bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
+
+      IR_NODE_LINK_TO(conv, bn_out);
+      found_conv_bn_count++;
+    } else {  // fuse_option == FUSE_NATIVE
+      // create an elementwise add node.
+      OpDesc desc;
+      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
+      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
+      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
+      desc.SetType("elementwise_add");
+      desc.SetAttr("axis", 1);
+      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
+
+      GraphSafeRemoveNodes(
+          graph.get(),
+          {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
+           bn_variance_out, bn_saved_mean, bn_saved_variance});
+
+      IR_NODE_LINK_TO(conv_out, eltwise_op);
+      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
+      IR_NODE_LINK_TO(eltwise_op, bn_out);
+      found_conv_bn_count++;
+    }
   };
 
   gpd(graph.get(), handler);
@@ -237,7 +278,6 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
         {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
          bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
 
-    PADDLE_ENFORCE(subgraph.count(conv_input));
     IR_NODE_LINK_TO(eltwise, bn_out);
 
     found_conv_bn_count++;
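For reference, the standard conv+BN folding algebra this pass relies on (the helper that fills eltwise_y_in_tensor is outside these hunks, so the exact weight handling is not shown here): an inference-time batch norm y = γ(x − μ)/√(σ² + ε) + β applied to a convolution output folds into a per-channel weight scale and a bias term, and that bias becomes either the conv Bias (MKL-DNN branch) or the Y input of the new elementwise_add (native branch):

\[
y \;=\; \gamma\,\frac{(W * z + b) - \mu}{\sqrt{\sigma^{2} + \varepsilon}} + \beta
\;=\; W' * z + b',
\qquad
W' = \frac{\gamma}{\sqrt{\sigma^{2} + \varepsilon}}\,W,
\quad
b' = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^{2} + \varepsilon}} + \beta .
\]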

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc

Lines changed: 6 additions & 0 deletions
@@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);  // Out
     GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);          // ReLU op
 
+    FuseOptions fuse_option = FindFuseOption(*conv, *relu);
+    if (fuse_option == DO_NOT_FUSE) {
+      VLOG(3) << "do not perform conv+relu fuse";
+      return;
+    }
+
     // Transform Conv node into ConvReLU node.
     OpDesc* desc = conv->Op();
     desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

Lines changed: 33 additions & 14 deletions
@@ -20,17 +20,19 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-void SetOp(ProgramDesc* prog, const std::string& type,
+void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs) {
+           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", true);
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("name", name);
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
     op->SetInput("Bias", {inputs[2]});
   } else if (type == "relu") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
     op->SetInput("X", inputs);
   }
   op->SetOutput("Out", outputs);
@@ -43,22 +45,33 @@ void SetOp(ProgramDesc* prog, const std::string& type,
 ProgramDesc BuildProgramDesc() {
   ProgramDesc prog;
   for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
+       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
+                                 "h", "weights2", "bias2", "k", "l"})) {
     auto* var = prog.MutableBlock(0)->Var(v);
     var->SetType(proto::VarType::SELECTED_ROWS);
     if (v == "weights" || v == "bias") {
       var->SetPersistable(true);
     }
   }
 
-  SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
+  SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
         std::vector<std::string>({"b"}));
-  SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
+  SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
         std::vector<std::string>({"c"}));
-  SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}));
-  SetOp(&prog, "relu", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}));
+  // conv+relu, both with MKL-DNN
+  SetOp(&prog, "conv2d", "conv1",
+        std::vector<std::string>({"c", "weights", "bias"}),
+        std::vector<std::string>({"f"}), true);
+  SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
+        std::vector<std::string>({"g"}), true);
+  SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
+        std::vector<std::string>({"h"}));
+  // conv+relu, only one with MKL-DNN
+  SetOp(&prog, "conv2d", "conv2",
+        std::vector<std::string>({"h", "weights2", "bias2"}),
+        std::vector<std::string>({"k"}), true);
+  SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
+        std::vector<std::string>({"l"}));
 
   return prog;
 }
@@ -88,10 +101,16 @@ TEST(ConvReLUFusePass, basic) {
       auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("fuse_relu"));
-      bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
-      if (fuse_relu) {
-        ++conv_relu_count;
+      // check if only "conv1" convolution is fused
+      auto op_name = boost::get<std::string>(op->GetAttr("name"));
+      if (op_name == "conv1") {
+        ASSERT_TRUE(op->HasAttr("fuse_relu"));
+        bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
+        if (fuse_relu) {
+          ++conv_relu_count;
+        }
+      } else if (op_name == "conv2") {
+        ASSERT_FALSE(op->HasAttr("fuse_relu"));
       }
     }
   }
