Skip to content

Commit 7a75cef

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_ut
2 parents 942bf06 + 27d6962 commit 7a75cef

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+280
-262
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
6565
option(WITH_ANAKIN "Compile with Anakin library" OFF)
6666
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
6767
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
68+
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
6869

6970
# CMAKE_BUILD_TYPE
7071
if(NOT CMAKE_BUILD_TYPE)

cmake/cblas.cmake

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,20 @@ else()
8383
set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
8484
endif()
8585

86-
find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
86+
if(WITH_SYSTEM_BLAS)
87+
find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
8788
${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS})
88-
find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
89+
find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
8990
${REFERENCE_CBLAS_LIB_SEARCH_PATHS})
9091

91-
if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
92-
set(CBLAS_FOUND ON)
93-
set(CBLAS_PROVIDER REFERENCE)
94-
set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
95-
set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
96-
add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
97-
message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
92+
if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
93+
set(CBLAS_FOUND ON)
94+
set(CBLAS_PROVIDER REFERENCE)
95+
set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
96+
set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
97+
add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
98+
message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
99+
endif()
98100
endif()
99101

100102
if(IOS_USE_VECLIB_FOR_BLAS AND VECLIB_FOUND)

paddle/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ if(NOT WITH_FLUID_ONLY)
77
add_subdirectory(legacy/parameter)
88

99
if(MOBILE_INFERENCE)
10-
add_subdirectory(capi)
10+
add_subdirectory(legacy/capi)
1111
else()
1212
add_subdirectory(legacy/pserver)
1313
add_subdirectory(trainer)
1414
add_subdirectory(scripts)
1515

1616
if(WITH_C_API)
17-
add_subdirectory(capi)
17+
add_subdirectory(legacy/capi)
1818
endif()
1919

2020
if(WITH_SWIG_PY)
21-
add_subdirectory(api)
21+
add_subdirectory(legacy/api)
2222
endif()
2323
endif()
2424
endif()

paddle/fluid/framework/op_registry.h

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,31 @@ class OpRegistry {
7676
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
7777
struct OpKernelRegistrarFunctor;
7878

79+
// Registers `func` as the kernel of operator `op_type` for one kernel key.
//
// The key is an OpKernelType built from:
//   - the data type derived from the template parameter T,
//   - a default-constructed PlaceType,
//   - a data layout: "ANYLAYOUT", or "MKLDNNLAYOUT" when `library_type`
//     compares equal to "MKLDNN",
//   - the library type parsed from `library_type`.
//
// `func` is stored by assignment into
// OperatorWithKernel::AllOpKernels()[op_type][key].
// NOTE(review): presumably an earlier registration under the same key is
// replaced by this assignment -- confirm against the type returned by
// AllOpKernels().
template <typename PlaceType, typename T, typename Func>
80+
inline void RegisterKernelClass(const char* op_type, const char* library_type,
81+
Func func) {
82+
std::string library(library_type);
83+
std::string data_layout = "ANYLAYOUT";
84+
if (library == "MKLDNN") {
85+
data_layout = "MKLDNNLAYOUT";
86+
}
87+
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
88+
StringToDataLayout(data_layout),
89+
StringToLibraryType(library_type));
90+
OperatorWithKernel::AllOpKernels()[op_type][key] = func;
91+
}
92+
7993
template <typename PlaceType, size_t I, typename... KernelTypes>
8094
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
8195
using KERNEL_TYPE =
8296
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
8397

8498
void operator()(const char* op_type, const char* library_type) const {
8599
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
86-
std::string library(library_type);
87-
std::string data_layout = "ANYLAYOUT";
88-
if (library == "MKLDNN") {
89-
data_layout = "MKLDNNLAYOUT";
90-
}
91-
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
92-
StringToDataLayout(data_layout),
93-
StringToLibraryType(library_type));
94-
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
95-
100+
RegisterKernelClass<PlaceType, T>(
101+
op_type, library_type, [](const framework::ExecutionContext& ctx) {
102+
KERNEL_TYPE().Compute(ctx);
103+
});
96104
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
97105
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
98106
func;
@@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar {
116124
}
117125
};
118126

127+
// Forward declaration of the recursive registration functor used by
// OpKernelRegistrarEx. `at_end` selects between the recursive step (false)
// and the terminating no-op (true); `I` is the current index into the
// `KernelType...` pack.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
128+
struct OpKernelRegistrarFunctorEx;
129+
130+
// Registrar whose constructor registers kernels for `op_type` by running
// OpKernelRegistrarFunctorEx starting at index 0 of the
// `DataTypeAndKernelType...` pack.
//
// The functor consumes the pack two entries at a time (a data type followed
// by a functor type), so the pack is expected to be a flat list of
// (data type, functor) pairs.
template <typename PlaceType, typename... DataTypeAndKernelType>
131+
class OpKernelRegistrarEx : public Registrar {
132+
public:
133+
// Performs the registration for the given operator name and library name.
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
134+
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
135+
func;
136+
func(op_type, library_type);
137+
}
138+
};
139+
140+
// Terminating specialization (at_end == true): nothing is left to register,
// so the call operator does nothing and the recursion stops here.
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
141+
struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
142+
DataTypeAndKernelType...> {
143+
void operator()(const char* op_type, const char* library_type) const {}
144+
};
145+
146+
// Recursive specialization (at_end == false): registers one
// (data type, functor) pair taken from the pack and then continues with the
// next pair.
//
// Pack element I is used as the data type `T` and element I + 1 as the
// `Functor` type. Both are looked up with std::tuple_element, so an index
// past the end of the pack is a compile-time error.
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
147+
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
148+
DataTypeAndKernelType...> {
149+
// Kernel functor type: the pack element following the data type.
using Functor =
150+
typename std::tuple_element<I + 1,
151+
std::tuple<DataTypeAndKernelType...>>::type;
152+
// Data type the functor is registered for.
using T =
153+
typename std::tuple_element<I,
154+
std::tuple<DataTypeAndKernelType...>>::type;
155+
156+
void operator()(const char* op_type, const char* library_type) const {
157+
// A default-constructed Functor instance is what gets registered.
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
158+
159+
constexpr auto size =
160+
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
161+
// Advance by two entries; the terminating specialization is selected once
// I + 2 reaches or passes the end of the pack.
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
162+
DataTypeAndKernelType...>
163+
func;
164+
func(op_type, library_type);
165+
}
166+
};
167+
119168
/**
120169
* check if MACRO is used in GLOBAL NAMESPACE.
121170
*/
@@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar {
174223
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
175224
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
176225

226+
// Registers functor-style kernels for operator `op_type`.
//
// Expands to:
//   1. a STATIC_ASSERT_GLOBAL_NAMESPACE check carrying the message
//      "REGISTER_OP_KERNEL_EX must be called in global namespace";
//   2. a static OpKernelRegistrarEx<place_class, __VA_ARGS__> object
//      constructed with the stringified `op_type` and `library_type`;
//   3. a function TouchOpKernelRegistrar_<op_type>_<library_type>() that
//      calls Touch() on that object and returns 0.
//
// __VA_ARGS__ is forwarded unchanged as the template arguments that follow
// `place_class`.
// NOTE(review): the Touch function presumably exists so other translation
// units can reference the registrar and keep it linked in -- confirm.
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
227+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
228+
__reg_op_kernel_##op_type##_##library_type##__, \
229+
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
230+
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
231+
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
232+
#library_type); \
233+
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
234+
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
235+
return 0; \
236+
}
237+
238+
// Shorthand for REGISTER_OP_KERNEL_EX with library type CUDA and place class
// ::paddle::platform::CUDAPlace; the variadic arguments are forwarded as is.
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
239+
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
240+
__VA_ARGS__)
241+
242+
// Shorthand for REGISTER_OP_KERNEL_EX with library type CPU and place class
// ::paddle::platform::CPUPlace; the variadic arguments are forwarded as is.
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
243+
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
244+
177245
/**
178246
* Macro to mark what Operator and Kernel
179247
* we will use and tell the compiler to

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
651651
dev_ctx = pool.Get(expected_kernel_key.place_);
652652
}
653653

654-
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
654+
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
655655

656656
if (!transfered_inplace_vars.empty()) {
657657
// there is an inplace variable that has been transferred.

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
347347

348348
class OperatorWithKernel : public OperatorBase {
349349
public:
350+
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
350351
using OpKernelMap =
351-
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
352-
OpKernelType::Hash>;
352+
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
353353

354354
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
355355
const VariableNameMap& outputs, const AttributeMap& attrs)

paddle/fluid/operators/fc_mkldnn_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class MKLDNNMemory {
115115

116116
template <typename T>
117117
class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
118+
public:
118119
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
119120
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
120121
"It must use CPUPlace.");

0 commit comments

Comments
 (0)