Skip to content

Commit d3e6095

Browse files
authored
[Cherry-pick] The Second part of new custom op extension in 2.0.1 (#31237)
[Cherry-pick] The Second part of new custom op extension in 2.0.1
1 parent 34092ab commit d3e6095

32 files changed

+1705
-504
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
293293
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
294294
set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
295295

296+
add_definitions(-DPADDLE_DLL_EXPORT)
297+
296298
if(ON_INFER)
297299
# you can trun off the paddle fluid and inference lib by set ON_INFER=OFF
298300
message(STATUS "On inference mode, will take place some specific optimization.")

cmake/generic.cmake

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -792,17 +792,15 @@ function(py_test TARGET_NAME)
792792

793793
if(WITH_COVERAGE)
794794
add_test(NAME ${TARGET_NAME}
795-
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
796-
FLAGS_cpu_deterministic=true
797-
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
798-
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
799-
${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS}
800-
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
795+
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
796+
FLAGS_cpu_deterministic=true ${py_test_ENVS}
797+
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
798+
${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS}
799+
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
801800
else()
802801
add_test(NAME ${TARGET_NAME}
803802
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
804-
FLAGS_cpu_deterministic=true
805-
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
803+
FLAGS_cpu_deterministic=true ${py_test_ENVS}
806804
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
807805
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
808806
endif()

paddle/fluid/extension/include/all.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ limitations under the License. */
1818
#error C++11 or later compatible compiler is required to use Paddle.
1919
#endif
2020

21+
#ifdef _WIN32
22+
#ifndef NOMINMAX
23+
#define NOMINMAX // msvc max/min macro conflict with std::min/max
24+
#endif
25+
#endif
26+
2127
#include "paddle/fluid/extension/include/dispatch.h"
2228
#include "paddle/fluid/extension/include/dtype.h"
2329
#include "paddle/fluid/extension/include/op_meta_info.h"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#if defined(_WIN32)
18+
#ifndef PD_DLL_DECL
19+
#ifdef PADDLE_DLL_EXPORT
20+
#define PD_DLL_DECL __declspec(dllexport)
21+
#else
22+
#define PD_DLL_DECL __declspec(dllimport)
23+
#endif // PADDLE_DLL_EXPORT
24+
#endif // PD_DLL_DECL
25+
#else
26+
#define PD_DLL_DECL
27+
#endif // _WIN32

paddle/fluid/extension/include/op_meta_info.h

Lines changed: 82 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <iostream>
1718
#include <string>
1819
#include <unordered_map>
1920
#include <vector>
2021

2122
#include <boost/any.hpp>
2223

24+
#include "paddle/fluid/extension/include/dll_decl.h"
2325
#include "paddle/fluid/extension/include/tensor.h"
2426

2527
/**
@@ -31,7 +33,7 @@ limitations under the License. */
3133

3234
namespace paddle {
3335
namespace framework {
34-
class OpMetaInfoHelper;
36+
class PD_DLL_DECL OpMetaInfoHelper;
3537
} // namespace framework
3638

3739
using Tensor = paddle::Tensor;
@@ -43,6 +45,26 @@ using Tensor = paddle::Tensor;
4345
classname& operator=(const classname&) = delete; \
4446
classname& operator=(classname&&) = delete
4547

48+
#if defined _WIN32
49+
#define HANDLE_THE_ERROR try {
50+
#define END_HANDLE_THE_ERROR \
51+
} \
52+
catch (const std::exception& e) { \
53+
std::cerr << e.what() << std::endl; \
54+
throw e; \
55+
}
56+
#else
57+
#define HANDLE_THE_ERROR
58+
#define END_HANDLE_THE_ERROR
59+
#endif
60+
61+
#define PD_THROW(err_msg) \
62+
do { \
63+
HANDLE_THE_ERROR \
64+
throw std::runtime_error(err_msg); \
65+
END_HANDLE_THE_ERROR \
66+
} while (0)
67+
4668
///////////////// Util Define and Function ////////////////
4769

4870
inline std::string Grad(const std::string& var_name) {
@@ -59,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
5981
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
6082
std::vector<boost::any> attrs);
6183

84+
#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
85+
template <typename... Tail> \
86+
struct ComputeCallHelper<attr_type, Tail...> { \
87+
template <int in_idx, int attr_idx, typename... PreviousArgs> \
88+
static Return Compute(std::vector<Tensor> inputs, \
89+
std::vector<boost::any> attrs, \
90+
const PreviousArgs&... pargs) { \
91+
try { \
92+
attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
93+
return ComputeCallHelper<Tail...>::template Compute<in_idx, \
94+
attr_idx + 1>( \
95+
inputs, attrs, pargs..., arg); \
96+
} catch (boost::bad_any_cast&) { \
97+
PD_THROW( \
98+
"Attribute cast error in custom operator. Expected " #attr_type \
99+
" value."); \
100+
} \
101+
} \
102+
}
103+
62104
template <typename T>
63105
struct TypeTag {};
64106

@@ -92,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
92134
}
93135
};
94136

95-
// TODO(chenweihang): add support for attribute input
96-
// int attribute input (not used now)
97-
template <typename... Tail>
98-
struct ComputeCallHelper<int, Tail...> {
99-
template <int in_idx, int attr_idx, typename... PreviousArgs>
100-
static Return Compute(std::vector<Tensor> inputs,
101-
std::vector<boost::any> attrs,
102-
const PreviousArgs&... pargs) {
103-
try {
104-
int arg = boost::any_cast<int>(attrs[attr_idx]);
105-
return ComputeCallHelper<Tail...>::template Compute<in_idx,
106-
attr_idx + 1>(
107-
inputs, attrs, pargs..., arg);
108-
} catch (boost::bad_any_cast&) {
109-
throw std::runtime_error(
110-
"Attribute cast error in custom operator. Expected int value.");
111-
}
112-
}
113-
};
114-
137+
PD_SPECIALIZE_ComputeCallHelper(bool);
138+
PD_SPECIALIZE_ComputeCallHelper(int);
139+
PD_SPECIALIZE_ComputeCallHelper(float);
140+
PD_SPECIALIZE_ComputeCallHelper(int64_t);
141+
PD_SPECIALIZE_ComputeCallHelper(std::string);
142+
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
143+
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
144+
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
145+
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
146+
// TODO(chenweihang): support other attribute type if needed.
147+
// Why not support other attribute type here?
148+
// - boost::blank, std::vector<bool> and std::vector<double>
149+
// are not used in op
150+
// - BlockDesc* and std::vector<BlockDesc*> are used in framework
115151
// end: base template
116152
template <typename T>
117153
struct ComputeCallHelper<TypeTag<T>> {
@@ -220,13 +256,26 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
220256

221257
////////////////////// Op Meta Info //////////////////////
222258

223-
class OpMetaInfo {
259+
class PD_DLL_DECL OpMetaInfo {
224260
public:
225261
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
262+
263+
// format: {"<name1>", "<name2>", ...}
226264
OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
265+
266+
// format: {"<name1>", "<name2>", ...}
227267
OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
268+
269+
// format: {"<name1>:<type1>", "<name1>:<type1>", ...}
270+
OpMetaInfo& Attrs(std::vector<std::string>&& attrs);
271+
272+
// format: PD_KERNEL(...)
228273
OpMetaInfo& SetKernelFn(KernelFunc&& func);
274+
275+
// format: PD_INFER_SHAPE(...)
229276
OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
277+
278+
// format: PD_INFER_DTYPE(...)
230279
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
231280

232281
private:
@@ -246,7 +295,7 @@ class OpMetaInfo {
246295

247296
//////////////// Op Meta Info Map /////////////////
248297

249-
class OpMetaInfoMap {
298+
class PD_DLL_DECL OpMetaInfoMap {
250299
public:
251300
// this function's impl should keep in header file.
252301
// if move to cc file, meta info can not be added
@@ -270,14 +319,15 @@ class OpMetaInfoMap {
270319

271320
//////////////// Op Meta Info Builder /////////////////
272321

273-
class OpMetaInfoBuilder {
322+
class PD_DLL_DECL OpMetaInfoBuilder {
274323
public:
275324
explicit OpMetaInfoBuilder(std::string&& name);
276325
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
277326
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
278-
OpMetaInfoBuilder& SetKernelFn(KernelFunc&& func);
279-
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc&& func);
280-
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc&& func);
327+
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
328+
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
329+
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
330+
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
281331
OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
282332

283333
private:
@@ -317,8 +367,12 @@ void LoadCustomOperatorLib(const std::string& dso_name);
317367
extern "C" {
318368
#endif
319369

370+
#if defined(_WIN32)
320371
// C-API to get global OpMetaInfoMap.
321-
paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap();
372+
__declspec(dllexport) inline paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
373+
return paddle::OpMetaInfoMap::Instance();
374+
}
375+
#endif // _WIN32
322376

323377
#ifdef __cplusplus
324378
}

paddle/fluid/extension/include/tensor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ limitations under the License. */
1616

1717
#include <memory>
1818
#include <vector>
19+
#include "paddle/fluid/extension/include/dll_decl.h"
1920
#include "paddle/fluid/extension/include/dtype.h"
2021
#include "paddle/fluid/extension/include/place.h"
2122

2223
namespace paddle {
2324
namespace framework {
2425
class CustomTensorUtils;
2526
} // namespace framework
26-
class Tensor {
27+
class PD_DLL_DECL Tensor {
2728
public:
2829
/// \brief Construct a Tensor on target Place for CustomOp.
2930
/// Generally it's only used for user to create Tensor.

paddle/fluid/extension/src/op_meta_info.cc

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
3232
outputs_ = std::forward<std::vector<std::string>>(outputs);
3333
return *this;
3434
}
35+
OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
36+
attrs_ = std::forward<std::vector<std::string>>(attrs);
37+
return *this;
38+
}
3539
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
3640
kernel_fn_ = std::forward<KernelFunc>(func);
3741
return *this;
@@ -78,17 +82,22 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
7882
return *this;
7983
}
8084

81-
OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc&& func) {
85+
OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
86+
info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
87+
return *this;
88+
}
89+
90+
OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
8291
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
8392
return *this;
8493
}
8594

86-
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc&& func) {
95+
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
8796
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
8897
return *this;
8998
}
9099

91-
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc&& func) {
100+
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
92101
info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
93102
return *this;
94103
}
@@ -114,10 +123,17 @@ void LoadCustomOperatorLib(const std::string& dso_name) {
114123
}
115124
} // namespace paddle
116125

126+
#ifdef __cplusplus
117127
extern "C" {
128+
#endif
118129

130+
#ifndef _WIN32
131+
// C-API to get global OpMetaInfoMap.
119132
paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
120133
return paddle::OpMetaInfoMap::Instance();
121134
}
135+
#endif
122136

137+
#ifdef __cplusplus
123138
} // end extern "C"
139+
#endif

0 commit comments

Comments
 (0)