Skip to content

Commit e5bf861

Browse files
committed
Merge branch 'develop' of https://github.com/paddlepaddle/paddle into add_trt_plugin
test=develop
2 parents d38fd6a + 38f499d commit e5bf861

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+2361
-223
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
4141
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
4242
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
4343
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
44+
option(WITH_NGRAPH "Compile PaddlePaddle with nGraph support." OFF)
4445
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
4546
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
4647
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
@@ -103,6 +104,8 @@ if(ANDROID OR IOS)
103104
"Disable RDMA when cross-compiling for Android and iOS" FORCE)
104105
set(WITH_MKL OFF CACHE STRING
105106
"Disable MKL when cross-compiling for Android and iOS" FORCE)
107+
set(WITH_NGRAPH OFF CACHE STRING
108+
"Disable nGraph when cross-compiling for Android and iOS" FORCE)
106109
set(WITH_GOLANG OFF CACHE STRING
107110
"Disable golang when cross-compiling for Android and iOS" FORCE)
108111

@@ -171,6 +174,7 @@ include(external/protobuf) # download, build, install protobuf
171174
include(external/python) # download, build, install python
172175
include(external/openblas) # download, build, install openblas
173176
include(external/mkldnn) # download, build, install mkldnn
177+
include(external/ngraph) # download, build, install nGraph
174178
include(external/swig) # download, build, install swig
175179
include(external/boost) # download boost
176180
include(external/any) # download libn::any

cmake/external/mkldnn.cmake

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
3737
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
3838

3939
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
40-
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include mkldnn.h
4140

4241
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
4342
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})

cmake/external/ngraph.cmake

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
add_library(ngraph INTERFACE)
16+
17+
IF(WIN32 OR APPLE)
18+
MESSAGE(WARNING
19+
"Windows or Mac is not supported with nGraph in Paddle yet."
20+
"Force WITH_NGRAPH=OFF")
21+
SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph in Windows and MacOS" FORCE)
22+
ENDIF()
23+
24+
IF(${WITH_NGRAPH} AND NOT ${WITH_MKLDNN})
25+
MESSAGE(WARNING
26+
"nGraph needs mkl-dnn to be enabled."
27+
"Force WITH_NGRAPH=OFF")
28+
SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph if mkl-dnn is disabled" FORCE)
29+
ENDIF()
30+
31+
IF(NOT ${WITH_NGRAPH})
32+
return()
33+
ENDIF()
34+
35+
INCLUDE(ExternalProject)
36+
37+
SET(NGRAPH_PROJECT "extern_ngraph")
38+
SET(NGRAPH_VERSION "0.9")
39+
SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
40+
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
41+
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
42+
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
43+
SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
44+
SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
45+
SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
46+
SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
47+
48+
ExternalProject_Add(
49+
${NGRAPH_PROJECT}
50+
${EXTERNAL_PROJECT_LOG_ARGS}
51+
DEPENDS ${MKLDNN_PROJECT} ${MKLML_PROJECT}
52+
GIT_REPOSITORY ${NGRAPH_GIT_REPO}
53+
GIT_TAG ${NGRAPH_GIT_TAG}
54+
PREFIX ${NGRAPH_SOURCES_DIR}
55+
UPDATE_COMMAND ""
56+
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
57+
CMAKE_ARGS -DNGRAPH_UNIT_TEST_ENABLE=FALSE
58+
CMAKE_ARGS -DNGRAPH_TOOLS_ENABLE=FALSE
59+
CMAKE_ARGS -DNGRAPH_INTERPRETER_ENABLE=FALSE
60+
CMAKE_ARGS -DNGRAPH_DEX_ONLY=TRUE
61+
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
62+
CMAKE_ARGS -DMKLDNN_INCLUDE_DIR=${MKLDNN_INC_DIR}
63+
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib
64+
)
65+
66+
if(UNIX AND NOT APPLE)
67+
include(GNUInstallDirs)
68+
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
69+
else()
70+
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib)
71+
endif()
72+
MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}")
73+
74+
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
75+
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
76+
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
77+
78+
# Workaround for nGraph expecting mklml to be in mkldnn install directory.
79+
ExternalProject_Add_Step(
80+
${NGRAPH_PROJECT}
81+
PrepareMKL
82+
COMMAND ${CMAKE_COMMAND} -E create_symlink ${MKLML_LIB} ${MKLDNN_INSTALL_DIR}/lib/libmklml_intel.so
83+
COMMAND ${CMAKE_COMMAND} -E create_symlink ${MKLML_IOMP_LIB} ${MKLDNN_INSTALL_DIR}/lib/libiomp5.so
84+
DEPENDEES download
85+
DEPENDERS configure
86+
)
87+
88+
add_dependencies(ngraph ${NGRAPH_PROJECT})
89+
target_compile_definitions(ngraph INTERFACE -DPADDLE_WITH_NGRAPH)
90+
target_include_directories(ngraph INTERFACE ${NGRAPH_INC_DIR})
91+
target_link_libraries(ngraph INTERFACE ${NGRAPH_SHARED_LIB})
92+
LIST(APPEND external_project_dependencies ngraph)

cmake/external/protobuf.cmake

Lines changed: 49 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -30,66 +30,61 @@ UNSET_VAR(PROTOBUF_LITE_LIBRARY)
3030
UNSET_VAR(PROTOBUF_LIBRARY)
3131
UNSET_VAR(PROTOBUF_INCLUDE_DIR)
3232
UNSET_VAR(Protobuf_PROTOC_EXECUTABLE)
33+
function(protobuf_generate_python SRCS)
34+
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
35+
if(NOT ARGN)
36+
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
37+
return()
38+
endif()
3339

34-
if(NOT COMMAND protobuf_generate_python) # before cmake 3.4, protobuf_genrerate_python is not defined.
35-
function(protobuf_generate_python SRCS)
36-
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
37-
if(NOT ARGN)
38-
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
39-
return()
40-
endif()
41-
42-
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
43-
# Create an include path for each file specified
44-
foreach(FIL ${ARGN})
45-
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
46-
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
47-
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
48-
if(${_contains_already} EQUAL -1)
49-
list(APPEND _protobuf_include_path -I ${ABS_PATH})
50-
endif()
51-
endforeach()
52-
else()
53-
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
54-
endif()
55-
56-
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
57-
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
58-
endif()
59-
60-
if(DEFINED Protobuf_IMPORT_DIRS)
61-
foreach(DIR ${Protobuf_IMPORT_DIRS})
62-
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
63-
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
64-
if(${_contains_already} EQUAL -1)
65-
list(APPEND _protobuf_include_path -I ${ABS_PATH})
66-
endif()
67-
endforeach()
68-
endif()
69-
70-
set(${SRCS})
40+
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
41+
# Create an include path for each file specified
7142
foreach(FIL ${ARGN})
7243
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
73-
get_filename_component(FIL_WE ${FIL} NAME_WE)
74-
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
75-
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
76-
if(FIL_DIR)
77-
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
78-
endif()
44+
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
45+
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
46+
if(${_contains_already} EQUAL -1)
47+
list(APPEND _protobuf_include_path -I ${ABS_PATH})
7948
endif()
49+
endforeach()
50+
else()
51+
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
52+
endif()
53+
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
54+
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
55+
endif()
8056

81-
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
82-
add_custom_command(
83-
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
84-
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
85-
DEPENDS ${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE}
86-
COMMENT "Running Python protocol buffer compiler on ${FIL}"
87-
VERBATIM )
57+
if(DEFINED Protobuf_IMPORT_DIRS)
58+
foreach(DIR ${Protobuf_IMPORT_DIRS})
59+
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
60+
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
61+
if(${_contains_already} EQUAL -1)
62+
list(APPEND _protobuf_include_path -I ${ABS_PATH})
63+
endif()
8864
endforeach()
65+
endif()
8966

90-
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
91-
endfunction()
92-
endif()
67+
set(${SRCS})
68+
foreach(FIL ${ARGN})
69+
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
70+
get_filename_component(FIL_WE ${FIL} NAME_WE)
71+
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
72+
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
73+
if(FIL_DIR)
74+
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
75+
endif()
76+
endif()
77+
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
78+
add_custom_command(
79+
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
80+
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
81+
DEPENDS ${ABS_FIL} ${PROTOBUF_PROTOC_EXECUTABLE}
82+
COMMENT "Running Python protocol buffer compiler on ${FIL}"
83+
VERBATIM )
84+
endforeach()
85+
86+
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
87+
endfunction()
9388

9489
# Print and set the protobuf library information,
9590
# finish this cmake process and exit from this file.
@@ -126,6 +121,7 @@ macro(PROMPT_PROTOBUF_LIB)
126121
# FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`.
127122
# make `protobuf_generate_cpp` happy.
128123
SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE})
124+
129125
FOREACH(dep ${protobuf_DEPS})
130126
ADD_DEPENDENCIES(protobuf ${dep})
131127
ADD_DEPENDENCIES(protobuf_lite ${dep})

paddle/fluid/API.spec

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
103103
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
104104
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
105105
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
106-
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False))
106+
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False))
107107
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
108108
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
109109
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
@@ -179,10 +179,12 @@ paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], vara
179179
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
180180
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
181181
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
182+
paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,))
182183
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
183184
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
184185
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
185186
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
187+
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
186188
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
187189
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
188190
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
@@ -201,6 +203,7 @@ paddle.fluid.layers.create_tensor ArgSpec(args=['dtype', 'name', 'persistable'],
201203
paddle.fluid.layers.create_parameter ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None))
202204
paddle.fluid.layers.create_global_var ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None))
203205
paddle.fluid.layers.cast ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None)
206+
paddle.fluid.layers.tensor_array_to_tensor ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
204207
paddle.fluid.layers.concat ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None))
205208
paddle.fluid.layers.sums ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,))
206209
paddle.fluid.layers.assign ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,))
@@ -271,6 +274,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k
271274
paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
272275
paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
273276
paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False))
277+
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None))
274278
paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False))
275279
paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
276280
paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None))

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,12 @@ struct HitGroup {
167167

168168
bool Match(Node *node, PDNode *pat) {
169169
if (nodes_.count(node)) {
170-
if (!roles.count(pat)) return false;
171-
return roles[pat] == node;
170+
if (roles.count(pat) && roles[pat] == node) return true;
171+
return false;
172+
} else {
173+
if (roles.count(pat) && roles[pat] != node) return false;
174+
return true;
172175
}
173-
return !roles.count(pat) || roles.at(pat) == node;
174176
}
175177

176178
void Register(Node *node, PDNode *pat) {
@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
198200
std::vector<GraphPatternDetector::subgraph_t> result;
199201
std::vector<HitGroup> init_groups;
200202
std::array<std::vector<HitGroup>, 2> bi_records;
201-
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
202203
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
203204
: pattern_.edges().front().first;
204205
if (!pdnodes2nodes_.count(first_pnode)) return result;
@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
228229
VLOG(80) << "check " << source->id() << " -- " << target->id();
229230
// TODO(Superjomn) add some prune strategies.
230231
for (const auto &group : pre_groups) {
231-
HitGroup new_group = group;
232-
if (IsNodesLink(source, target) &&
233-
new_group.Match(source, edge.first)) {
234-
new_group.Register(source, edge.first);
235-
if (new_group.Match(target, edge.second)) {
232+
if (IsNodesLink(source, target)) {
233+
HitGroup new_group = group;
234+
bool flag = new_group.Match(source, edge.first) &&
235+
new_group.Match(target, edge.second);
236+
if (flag) {
237+
new_group.Register(source, edge.first);
236238
new_group.Register(target, edge.second);
237239
cur_groups.push_back(new_group);
238240
// TODO(Superjomn) need to unique
@@ -261,14 +263,16 @@ GraphPatternDetector::DetectPatterns() {
261263
return result;
262264
}
263265

264-
bool GraphItemCMP(const std::pair<PDNode *, Node *> &a,
266+
struct GraphItemLessThan {
267+
bool operator()(const std::pair<PDNode *, Node *> &a,
265268
const std::pair<PDNode *, Node *> &b) {
266-
if (a.first != b.first) {
267-
return a.first < b.first;
268-
} else {
269-
return a.second < b.second;
269+
if (a.first != b.first) {
270+
return a.first < b.first;
271+
} else {
272+
return a.second < b.second;
273+
}
270274
}
271-
}
275+
};
272276

273277
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
274278
// see https://github.com/PaddlePaddle/Paddle/issues/13550
@@ -282,7 +286,7 @@ void GraphPatternDetector::UniquePatterns(
282286
for (auto &g : *subgraphs) {
283287
// Sort the items in the sub-graph, and transform to a string key.
284288
std::vector<std::pair<PDNode *, Node *>> sorted_keys(g.begin(), g.end());
285-
std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemCMP);
289+
std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemLessThan());
286290
std::stringstream ss;
287291
for (auto &item : sorted_keys) {
288292
ss << item.first << ":" << item.second;

0 commit comments

Comments
 (0)