@@ -14,12 +14,14 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < iostream>
17
18
#include < string>
18
19
#include < unordered_map>
19
20
#include < vector>
20
21
21
22
#include < boost/any.hpp>
22
23
24
+ #include " paddle/fluid/extension/include/dll_decl.h"
23
25
#include " paddle/fluid/extension/include/tensor.h"
24
26
25
27
/* *
@@ -31,7 +33,7 @@ limitations under the License. */
31
33
32
34
namespace paddle {
33
35
namespace framework {
34
- class OpMetaInfoHelper ;
36
+ class PD_DLL_DECL OpMetaInfoHelper;
35
37
} // namespace framework
36
38
37
39
using Tensor = paddle::Tensor;
@@ -43,6 +45,26 @@ using Tensor = paddle::Tensor;
43
45
classname& operator =(const classname&) = delete ; \
44
46
classname& operator =(classname&&) = delete
45
47
48
+ #if defined _WIN32
49
+ #define HANDLE_THE_ERROR try {
50
+ #define END_HANDLE_THE_ERROR \
51
+ } \
52
+ catch (const std::exception& e) { \
53
+ std::cerr << e.what () << std::endl; \
54
+ throw e; \
55
+ }
56
+ #else
57
+ #define HANDLE_THE_ERROR
58
+ #define END_HANDLE_THE_ERROR
59
+ #endif
60
+
61
+ #define PD_THROW (err_msg ) \
62
+ do { \
63
+ HANDLE_THE_ERROR \
64
+ throw std::runtime_error (err_msg); \
65
+ END_HANDLE_THE_ERROR \
66
+ } while (0 )
67
+
46
68
// /////////////// Util Define and Function ////////////////
47
69
48
70
inline std::string Grad (const std::string& var_name) {
@@ -59,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
59
81
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
60
82
std::vector<boost::any> attrs);
61
83
84
+ #define PD_SPECIALIZE_ComputeCallHelper (attr_type ) \
85
+ template <typename ... Tail> \
86
+ struct ComputeCallHelper <attr_type, Tail...> { \
87
+ template <int in_idx, int attr_idx, typename ... PreviousArgs> \
88
+ static Return Compute (std::vector<Tensor> inputs, \
89
+ std::vector<boost::any> attrs, \
90
+ const PreviousArgs&... pargs) { \
91
+ try { \
92
+ attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
93
+ return ComputeCallHelper<Tail...>::template Compute<in_idx, \
94
+ attr_idx + 1 >( \
95
+ inputs, attrs, pargs..., arg); \
96
+ } catch (boost::bad_any_cast&) { \
97
+ PD_THROW ( \
98
+ " Attribute cast error in custom operator. Expected " #attr_type \
99
+ " value." ); \
100
+ } \
101
+ } \
102
+ }
103
+
62
104
template <typename T>
63
105
struct TypeTag {};
64
106
@@ -92,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
92
134
}
93
135
};
94
136
95
- // TODO(chenweihang): add support for attribute input
96
- // int attribute input (not used now)
97
- template <typename ... Tail>
98
- struct ComputeCallHelper <int , Tail...> {
99
- template <int in_idx, int attr_idx, typename ... PreviousArgs>
100
- static Return Compute (std::vector<Tensor> inputs,
101
- std::vector<boost::any> attrs,
102
- const PreviousArgs&... pargs) {
103
- try {
104
- int arg = boost::any_cast<int >(attrs[attr_idx]);
105
- return ComputeCallHelper<Tail...>::template Compute<in_idx,
106
- attr_idx + 1 >(
107
- inputs, attrs, pargs..., arg);
108
- } catch (boost::bad_any_cast&) {
109
- throw std::runtime_error (
110
- " Attribute cast error in custom operator. Expected int value." );
111
- }
112
- }
113
- };
114
-
137
+ PD_SPECIALIZE_ComputeCallHelper (bool );
138
+ PD_SPECIALIZE_ComputeCallHelper (int );
139
+ PD_SPECIALIZE_ComputeCallHelper (float );
140
+ PD_SPECIALIZE_ComputeCallHelper (int64_t );
141
+ PD_SPECIALIZE_ComputeCallHelper (std::string);
142
+ PD_SPECIALIZE_ComputeCallHelper (std::vector<int >);
143
+ PD_SPECIALIZE_ComputeCallHelper (std::vector<float >);
144
+ PD_SPECIALIZE_ComputeCallHelper (std::vector<int64_t >);
145
+ PD_SPECIALIZE_ComputeCallHelper (std::vector<std::string>);
146
+ // TODO(chenweihang): support other attribute type if needed.
147
+ // Why not support other attribute type here?
148
+ // - boost::blank, std::vector<bool> and std::vector<double>
149
+ // are not used in op
150
+ // - BlockDesc* and std::vector<BlockDesc*> are used in framework
115
151
// end: base template
116
152
template <typename T>
117
153
struct ComputeCallHelper <TypeTag<T>> {
@@ -220,13 +256,26 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
220
256
221
257
// //////////////////// Op Meta Info //////////////////////
222
258
223
- class OpMetaInfo {
259
+ class PD_DLL_DECL OpMetaInfo {
224
260
public:
225
261
explicit OpMetaInfo (const std::string& op_name) : name_(op_name) {}
262
+
263
+ // format: {"<name1>", "<name2>", ...}
226
264
OpMetaInfo& Inputs (std::vector<std::string>&& inputs);
265
+
266
+ // format: {"<name1>", "<name2>", ...}
227
267
OpMetaInfo& Outputs (std::vector<std::string>&& outputs);
268
+
269
+ // format: {"<name1>:<type1>", "<name1>:<type1>", ...}
270
+ OpMetaInfo& Attrs (std::vector<std::string>&& attrs);
271
+
272
+ // format: PD_KERNEL(...)
228
273
OpMetaInfo& SetKernelFn (KernelFunc&& func);
274
+
275
+ // format: PD_INFER_SHAPE(...)
229
276
OpMetaInfo& SetInferShapeFn (InferShapeFunc&& func);
277
+
278
+ // format: PD_INFER_DTYPE(...)
230
279
OpMetaInfo& SetInferDtypeFn (InferDtypeFunc&& func);
231
280
232
281
private:
@@ -246,7 +295,7 @@ class OpMetaInfo {
246
295
247
296
// ////////////// Op Meta Info Map /////////////////
248
297
249
- class OpMetaInfoMap {
298
+ class PD_DLL_DECL OpMetaInfoMap {
250
299
public:
251
300
// this function's impl should keep in header file.
252
301
// if move to cc file, meta info can not be added
@@ -270,14 +319,15 @@ class OpMetaInfoMap {
270
319
271
320
// ////////////// Op Meta Info Builder /////////////////
272
321
273
- class OpMetaInfoBuilder {
322
+ class PD_DLL_DECL OpMetaInfoBuilder {
274
323
public:
275
324
explicit OpMetaInfoBuilder (std::string&& name);
276
325
OpMetaInfoBuilder& Inputs (std::vector<std::string>&& inputs);
277
326
OpMetaInfoBuilder& Outputs (std::vector<std::string>&& outputs);
278
- OpMetaInfoBuilder& SetKernelFn (KernelFunc&& func);
279
- OpMetaInfoBuilder& SetInferShapeFn (InferShapeFunc&& func);
280
- OpMetaInfoBuilder& SetInferDtypeFn (InferDtypeFunc&& func);
327
+ OpMetaInfoBuilder& Attrs (std::vector<std::string>&& attrs);
328
+ OpMetaInfoBuilder& SetKernelFn (KernelFunc func);
329
+ OpMetaInfoBuilder& SetInferShapeFn (InferShapeFunc func);
330
+ OpMetaInfoBuilder& SetInferDtypeFn (InferDtypeFunc func);
281
331
OpMetaInfoBuilder& SetBackwardOp (const std::string& bwd_op_name);
282
332
283
333
private:
@@ -317,8 +367,12 @@ void LoadCustomOperatorLib(const std::string& dso_name);
317
367
extern " C" {
318
368
#endif
319
369
370
+ #if defined(_WIN32)
320
371
// C-API to get global OpMetaInfoMap.
321
- paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap ();
372
+ __declspec (dllexport) inline paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
373
+ return paddle::OpMetaInfoMap::Instance ();
374
+ }
375
+ #endif // _WIN32
322
376
323
377
#ifdef __cplusplus
324
378
}
0 commit comments