Skip to content

Commit f418f55

Browse files
committed
Merge branch 'develop' into develop_7a64d48f5_stack_opt (test=develop)
2 parents 03ccb9a + fd7e643 commit f418f55

File tree

302 files changed

+3771
-1402
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

302 files changed

+3771
-1402
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,6 @@ endif()
315315

316316
if (ON_INFER)
317317
message(STATUS "On inference mode, will take place some specific optimization.")
318-
add_definitions(-DPADDLE_ON_INFERENCE)
319318
else()
320319
#TODO(luotao), combine this warning with `make inference_lib_dist` command.
321320
message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")

cmake/configure.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,7 @@ endif(WITH_GRPC)
218218
if(WITH_BRPC_RDMA)
219219
add_definitions(-DPADDLE_WITH_BRPC_RDMA)
220220
endif(WITH_BRPC_RDMA)
221+
222+
if(ON_INFER)
223+
add_definitions(-DPADDLE_ON_INFERENCE)
224+
endif(ON_INFER)

cmake/operators.cmake

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
set(PART_CUDA_KERNEL_FILES)
2+
# op_library(<target> [SRCS <file>...] [DEPS <dep>...])
#
# Creates an operator library. The interface mirrors cc_library, but the
# function additionally splits CPU / CUDA / HIP / cuDNN / MIOpen / MKLDNN
# sources, links a common set of operator dependencies, and appends the
# matching USE_* registration line(s) to ${pybind_file}.
#
# When SRCS is omitted, sources are discovered in CMAKE_CURRENT_SOURCE_DIR
# from the target name (<target>.cc, <target>.cu.cc, <target>.cu, ...).
#
# Variables read from the enclosing scope: WITH_GPU, WITH_AMD_GPU,
# WITH_MKLDNN, pybind_file, PART_CUDA_KERNEL_FILES, OP_LIBRARY, DEPS_OPS.
# Variables written outside this scope: PART_CUDA_KERNEL_FILES and DEPS_OPS
# (PARENT_SCOPE), OP_LIBRARY (internal cache entry).
function(op_library TARGET)
    set(cc_srcs)
    set(cu_srcs)
    set(hip_cu_srcs)
    set(miopen_hip_cc_srcs)
    set(cu_cc_srcs)
    set(cudnn_cu_cc_srcs)
    set(CUDNN_FILE)
    set(mkldnn_cc_srcs)
    set(MKLDNN_FILE)
    set(op_common_deps operator op_registry math_function)
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    # pybind_flag becomes 1 once a registration line has been emitted (or the
    # operator is registered manually), which suppresses the later fallbacks.
    set(pybind_flag 0)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    list(LENGTH op_library_SRCS op_library_SRCS_len)
    if (${op_library_SRCS_len} EQUAL 0)
        # No explicit sources: probe the current source directory for files
        # derived from the target name.
        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc")
            list(APPEND cc_srcs ${TARGET}.cc)
        endif()
        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc")
            list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
        endif()
        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu")
            list(APPEND cu_srcs ${TARGET}.cu)
        endif()
        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu")
            # The .part.cu file is also recorded (by absolute path) in the
            # caller's PART_CUDA_KERNEL_FILES list.
            set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                    ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
            list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
        endif()

        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu")
            list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
        endif()
        string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
        if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc")
            list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
        endif()
        if(WITH_AMD_GPU)
            string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
            if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc")
                list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
            endif()
        endif()
        if(WITH_MKLDNN)
            string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
            if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc")
                list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc)
            endif()
        endif()
    else()
        # Explicit sources: classify each file by suffix. The order of the
        # checks matters because the later patterns are less specific.
        # "${src}" is quoted so a file name that happens to match a variable
        # name is compared as a string instead of being dereferenced again.
        foreach(src ${op_library_SRCS})
            if ("${src}" MATCHES ".*\\.hip.cu$")
                list(APPEND hip_cu_srcs ${src})
            elseif ("${src}" MATCHES ".*\\.cu$")
                list(APPEND cu_srcs ${src})
            elseif("${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND cudnn_cu_cc_srcs ${src})
            elseif(WITH_AMD_GPU AND "${src}" MATCHES ".*_miopen_op.hip.cc$")
                list(APPEND miopen_hip_cc_srcs ${src})
            elseif(WITH_MKLDNN AND "${src}" MATCHES ".*_mkldnn_op.cc$")
                list(APPEND mkldnn_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND cu_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cc$")
                list(APPEND cc_srcs ${src})
            else()
                message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
            endif()
        endforeach()
    endif()

    list(LENGTH cc_srcs cc_srcs_len)
    if (${cc_srcs_len} EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()
    if (WIN32)
        # Skip operators that are not supported on Windows (there is no nccl,
        # no warpctc, etc.). Returning here means no library and no pybind
        # registration are produced for them.
        foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
                "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
                "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
            if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
                return()
            endif()
        endforeach()
    endif(WIN32)
    # Prepend this operator to the global (cached) list of operator libraries.
    set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")

    list(LENGTH op_library_DEPS op_library_DEPS_len)
    if (${op_library_DEPS_len} GREATER 0)
        # Operators with extra dependencies are also recorded in the caller's
        # DEPS_OPS list.
        set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
    endif()
    if (WITH_GPU)
        nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    elseif (WITH_AMD_GPU)
        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    else()
        cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    endif()

    # Operators listed here are registered for pybind manually, so no
    # automatic USE_* line is generated for them below.
    foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
            "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
        if ("${TARGET}" STREQUAL "${manual_pybind_op}")
            set(pybind_flag 1)
        endif()
    endforeach()

    # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
    # Note that it is enough to add just one operator to pybind per *_op.cc file.
    # For detailed pybind information, see the generated paddle/pybind/pybind.h.
    #
    # If the file registers more than one operator, the name of the first
    # REGISTER_OPERATOR(...) is used; otherwise the name is derived from the
    # target by stripping "_op". From here on TARGET holds the operator name,
    # not the library target name.
    file(READ ${TARGET}.cc TARGET_CONTENT)
    string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
    string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
    if (one_register STREQUAL "")
        string(REPLACE "_op" "" TARGET "${TARGET}")
    else ()
        string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
        string(REPLACE "," "" TARGET "${TARGET}")
    endif()

    # pybind USE_NO_KERNEL_OP
    # HACK: if REGISTER_OP_CPU_KERNEL is present the operator must have a kernel.
    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
    string(REPLACE "_op" "" TARGET "${TARGET}")
    if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
        file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_CPU_ONLY_OP: no device-specific or MKLDNN sources were found.
    list(LENGTH cu_srcs cu_srcs_len)
    list(LENGTH cu_cc_srcs cu_cc_srcs_len)
    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
    list(LENGTH hip_cu_srcs hip_cu_srcs_len)
    list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
    if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
        ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_OP_DEVICE_KERNEL for CUDNN
    list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
    if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MIOPEN
    if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MKLDNN
    if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
        # Append first implemented MKLDNN activation operator.
        # MKLDNN_FILE is only assigned in the auto-discovery branch above; it
        # is empty when SRCS was passed explicitly, so the expansion must be
        # quoted to keep this if() well-formed.
        if ("${MKLDNN_FILE}" STREQUAL "activation_mkldnn_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
        else()
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
        endif()
    endif()

    # pybind USE_OP
    if (${pybind_flag} EQUAL 0)
        # NOTE(*): activation uses a macro to register its kernels, so USE_OP
        # is set manually. The operator name is quoted so it is compared as a
        # string rather than being looked up as a variable name.
        if("${TARGET}" STREQUAL "activation")
            file(APPEND ${pybind_file} "USE_OP(relu);\n")
        elseif("${TARGET}" STREQUAL "fake_dequantize")
            file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
        elseif("${TARGET}" STREQUAL "fake_quantize")
            file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
        elseif("${TARGET}" STREQUAL "tensorrt_engine_op")
            # NOTE(review): "_op" has already been stripped from TARGET above
            # and tensorrt_engine_op sets pybind_flag in the manual list, so
            # this branch does not appear to be reachable.
            message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
        elseif("${TARGET}" STREQUAL "fc")
            # HACK: fc only has mkldnn and cpu kernels, which would not match
            # the CPU-only condition above.
            file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
        else()
            file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
        endif()
    endif()
endfunction()
194+
195+
196+
# register_operators([EXCLUDES <op>...] [DEPS <dep>...])
#
# Globs every "*_op.cc" file in the current source directory, turns the file
# names into operator names (dropping "_mkldnn" and ".cc", then removing
# duplicates) and calls op_library() for each name that is not listed in
# EXCLUDES. When DEPS is given, the same dependency list is forwarded to every
# op_library() call.
function(register_operators)
    cmake_parse_arguments(register_operators "" "" "EXCLUDES;DEPS" ${ARGN})

    file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
    string(REPLACE "_mkldnn" "" OPS "${OPS}")
    string(REPLACE ".cc" "" OPS "${OPS}")
    list(REMOVE_DUPLICATES OPS)

    # Build the optional "DEPS ..." tail once; it expands to nothing when no
    # dependencies were supplied, so op_library() then receives only the name.
    set(forwarded_deps)
    list(LENGTH register_operators_DEPS dep_count)
    if (dep_count GREATER 0)
        set(forwarded_deps DEPS ${register_operators_DEPS})
    endif()

    foreach(op_name ${OPS})
        list(FIND register_operators_EXCLUDES ${op_name} excluded_at)
        if (excluded_at EQUAL -1)
            op_library(${op_name} ${forwarded_deps})
        endif()
    endforeach()
endfunction()

paddle/fluid/API.spec

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized',
9393
paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None))
9494
paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None))
9595
paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,))
96-
paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times'], varargs=None, keywords=None, defaults=(0, False))
96+
paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False))
9797
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
9898
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
9999
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
100-
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None))
100+
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
101101
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
102102
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
103103
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
@@ -128,6 +128,7 @@ paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates',
128128
paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
129129
paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
130130
paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
131+
paddle.fluid.layers.selu ArgSpec(args=['x', 'scale', 'alpha', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
131132
paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
132133
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
133134
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/framework/data_device_transform_test.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/lod_tensor.h"
1818
#include "paddle/fluid/framework/op_info.h"
1919
#include "paddle/fluid/framework/op_registry.h"
20-
#include "paddle/fluid/operators/elementwise_op_function.h"
20+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
2121
#include "paddle/fluid/operators/math/math_function.h"
2222
#include "paddle/fluid/platform/device_context.h"
2323
#include "paddle/fluid/platform/init.h"

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pass_library(seq_concat_fc_fuse_pass inference)
4141
pass_library(multi_batch_merge_pass base)
4242
pass_library(conv_bn_fuse_pass inference)
4343
pass_library(seqconv_eltadd_relu_fuse_pass inference)
44+
pass_library(is_test_pass base)
4445
if(WITH_MKLDNN)
4546
pass_library(mkldnn_placement_pass base)
4647
pass_library(depthwise_conv_mkldnn_pass base)
@@ -62,6 +63,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
6263
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
6364
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
6465
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
66+
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
6567
if (WITH_MKLDNN)
6668
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
6769
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
5757
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
5858
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
5959
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
60+
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
6061
desc.SetType("fc");
6162
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
6263
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});

paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
2929
if (type == "mul") {
3030
op->SetInput("X", {inputs[0]});
3131
op->SetInput("Y", {inputs[1]});
32+
op->SetAttr("x_num_col_dims", {1});
3233
} else if (type == "elementwise_add") {
3334
op->SetInput("X", inputs);
3435
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/is_test_pass.h"
16+
#include <string>
17+
#include <utility>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
24+
std::unique_ptr<ir::Graph> graph) const {
25+
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
26+
"for activations and pooling.";
27+
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
28+
"softshrink", "exp", "brelu",
29+
"pow", "leaky_relu", "stanh",
30+
"relu", "tanh", "tanh_shrink",
31+
"sqrt", "abs", "ceil",
32+
"elu", "floor", "cos",
33+
"sin", "round", "reciprocal",
34+
"hard_shrink", "hard_sigmoid", "relu6",
35+
"soft_relu", "swish", "thresholded_relu",
36+
"log", "square", "softplus",
37+
"softsign"};
38+
for (const Node* n : graph->Nodes()) {
39+
if (n->IsOp()) {
40+
auto* op = n->Op();
41+
if (op->HasAttr("is_test")) {
42+
op->SetAttr("is_test", true);
43+
} else if (std::find(begin(op_list), end(op_list), op->Type()) !=
44+
end(op_list)) {
45+
op->MutableAttrMap()->insert(
46+
std::pair<std::string, Attribute>("is_test", true));
47+
}
48+
}
49+
}
50+
return graph;
51+
}
52+
53+
} // namespace ir
54+
} // namespace framework
55+
} // namespace paddle
56+
57+
REGISTER_PASS(is_test_pass, paddle::framework::ir::IsTestPass);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/ir/pass.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
class IsTestPass : public Pass {
24+
protected:
25+
std::unique_ptr<ir::Graph> ApplyImpl(
26+
std::unique_ptr<ir::Graph> graph) const override;
27+
};
28+
29+
} // namespace ir
30+
} // namespace framework
31+
} // namespace paddle

0 commit comments

Comments
 (0)