
Commit 52f7e77

Aurelius84 and chenwhql authored
[Cherry-Pick] Split Macros and Add modeling unittest (#31266)
* [CustomOp] Add Modeling with Custom op unittest (#31218)
  * add unittest for static/dygraph/dy2stat
  * add PE unittest
  * remove useless code
  * add unittest in CMakeList.txt
* [CustomOp] Split build op macro & polish details (#31229)
  * split build op macro & polish details
  * revert register api del
  * fix other unittest
* [CustomOp] Support Incremental compilation and Add Version management (#31228)
  * Support Incremental compilation and Add Version management
  * replace hash with hashlib
  * fix test_op_num unittest
  * Revert "fix test_op_num unittest"
    This reverts commit 2f78de9.

Co-authored-by: Chen Weihang <[email protected]>
1 parent 536d9a3 commit 52f7e77

File tree

12 files changed: +646 -144 lines changed

paddle/fluid/extension/include/op_meta_info.h

Lines changed: 35 additions & 14 deletions
@@ -38,6 +38,8 @@ class PD_DLL_DECL OpMetaInfoHelper;
 
 using Tensor = paddle::Tensor;
 
+///////////////// Util Marco Define ////////////////
+
 #define PD_DISABLE_COPY_AND_ASSIGN(classname)  \
  private:                                      \
   classname(const classname&) = delete;        \
@@ -65,6 +67,12 @@ using Tensor = paddle::Tensor;
     END_HANDLE_THE_ERROR                       \
   } while (0)
 
+#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
+  struct __test_global_namespace_##uniq_name##__ {};                          \
+  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
+                             __test_global_namespace_##uniq_name##__>::value, \
+                msg)
+
 ///////////////// Util Define and Function ////////////////
 
 inline std::string Grad(const std::string& var_name) {
@@ -288,9 +296,9 @@ class PD_DLL_DECL OpMetaInfo {
   std::vector<std::string> attrs_;
 
   // 2. func info
-  KernelFunc kernel_fn_;
-  InferShapeFunc infer_shape_fn_;
-  InferDtypeFunc infer_dtype_fn_;
+  KernelFunc kernel_fn_{nullptr};
+  InferShapeFunc infer_shape_fn_{nullptr};
+  InferDtypeFunc infer_dtype_fn_{nullptr};
 };
 
 //////////////// Op Meta Info Map /////////////////
@@ -321,20 +329,22 @@ class PD_DLL_DECL OpMetaInfoMap {
 
 class PD_DLL_DECL OpMetaInfoBuilder {
  public:
-  explicit OpMetaInfoBuilder(std::string&& name);
+  explicit OpMetaInfoBuilder(std::string&& name, size_t index);
   OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
   OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
   OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
   OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
   OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
   OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
-  OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
 
  private:
   // Forward Op name
   std::string name_;
-  // Point to the currently constructed op meta info
+  // ref current info ptr
   OpMetaInfo* info_ptr_;
+  // The current op meta info index in vector
+  // - 0: op, 1: grad_op, 2: grad_grad_op
+  size_t index_;
 };
 
 /////////////////////// Op register API /////////////////////////
@@ -350,14 +360,25 @@ void LoadCustomOperatorLib(const std::string& dso_name);
 
 /////////////////////// Op register Macro /////////////////////////
 
-#define PD_BUILD_OP_WITH_COUNTER(op_name, counter)                  \
-  static ::paddle::OpMetaInfoBuilder __op_meta_info_##counter##__ = \
-      ::paddle::OpMetaInfoBuilder(op_name)
-
-#define PD_BUILD_OP_INNER(op_name, counter) \
-  PD_BUILD_OP_WITH_COUNTER(op_name, counter)
-
-#define PD_BUILD_OP(op_name) PD_BUILD_OP_INNER(op_name, __COUNTER__)
+#define PD_BUILD_OP(op_name)                                                   \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
+      __reg_op__##op_name, "PD_BUILD_OP must be called in global namespace."); \
+  static ::paddle::OpMetaInfoBuilder __op_meta_info_##op_name##__ =            \
+      ::paddle::OpMetaInfoBuilder(#op_name, 0)
+
+#define PD_BUILD_GRAD_OP(op_name)                                        \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
+      __reg_grad_op__##op_name,                                          \
+      "PD_BUILD_GRAD_OP must be called in global namespace.");           \
+  static ::paddle::OpMetaInfoBuilder __grad_op_meta_info_##op_name##__ = \
+      ::paddle::OpMetaInfoBuilder(#op_name, 1)
+
+#define PD_BUILD_DOUBLE_GRAD_OP(op_name)                                      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
+      __reg_grad_grad_op__##op_name,                                          \
+      "PD_BUILD_DOUBLE_GRAD_OP must be called in global namespace.");         \
+  static ::paddle::OpMetaInfoBuilder __grad_grad_op_meta_info_##op_name##__ = \
+      ::paddle::OpMetaInfoBuilder(#op_name, 2)
 
 } // namespace paddle
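
For orientation, a hedged usage sketch of the split registration macros declared above. It is not part of this commit: the op name and the `Relu*` functions are illustrative assumptions, and `PD_KERNEL` / `PD_INFER_SHAPE` / `PD_INFER_DTYPE` are the wrapper macros referenced by the error messages elsewhere in this diff.

// Illustrative only -- names are assumptions, not from this diff.
#include "paddle/extension.h"  // assumed public custom-op header

// ReluForward / ReluBackward / ReluInferShape / ReluInferDtype are assumed to
// be custom kernel and infer functions defined elsewhere.
// Must be written at global namespace scope, otherwise the
// STATIC_ASSERT_GLOBAL_NAMESPACE guard fails at compile time.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDtype));

// The backward op is now registered with its own macro instead of chaining
// SetBackwardOp() onto the forward builder; the builder appends "_grad" itself.
PD_BUILD_GRAD_OP(custom_relu)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));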

paddle/fluid/extension/src/op_meta_info.cc

Lines changed: 41 additions & 10 deletions
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/framework/custom_operator.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 
@@ -62,11 +63,38 @@ OpMetaInfoMap::GetMap() const {
 
 //////////////// Op Meta Info Builder /////////////////
 
-OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) {
+OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name, size_t index) {
+  // 1. member assign
   name_ = std::forward<std::string>(name);
+  index_ = index;
+
+  // 2. check and meta info build
   auto& info_vector = OpMetaInfoMap::Instance()[name_];
+  // index check
+  PADDLE_ENFORCE_EQ(
+      info_vector.size(), index_,
+      platform::errors::PreconditionNotMet(
+          "The operator %s's meta info register failed. "
+          "Please make sure you call marcos as order `PD_BUILD_OP`, "
+          "`PD_BUILD_GRAD_OP`, `PD_BUILD_DOUBLE_GRAD_OP`.",
+          name_));
+  switch (index_) {
+    case 0:
+      break;
+    case 1:
+      name_ = name_ + "_grad";
+      break;
+    case 2:
+      name_ = name_ + "_grad_grad";
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Not support index `%d` when construct OpMetaInfoBuilder, "
+          "now only support `0, 1, 2`.",
+          index_));
+  }
   auto op_meta = OpMetaInfo(name_);
   info_vector.emplace_back(std::move(op_meta));
+  // 3. get current info ptr
   info_ptr_ = &(info_vector.back());
 }
 
@@ -93,24 +121,27 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
 }
 
 OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
+  PADDLE_ENFORCE_EQ(
+      index_, 0UL,
+      platform::errors::Unimplemented(
+          "Currently, the InferShapeFn setting of Grad Op is not supported, "
+          "And backward Tensor `X@GRAD` will use the shape of forward Tensor "
+          "`X` by default."));
   info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
   return *this;
 }
 
 OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
+  PADDLE_ENFORCE_EQ(
+      index_, 0UL,
+      platform::errors::Unimplemented(
+          "Currently, the InferDtypeFn setting of Grad Op is not supported, "
+          "And backward Tensor `X@GRAD` will use the dtype of forward Tensor "
+          "`X` by default."));
   info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
   return *this;
 }
 
-OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp(
-    const std::string& bwd_op_name) {
-  auto& info_vector = OpMetaInfoMap::Instance()[name_];
-  auto op_meta = OpMetaInfo(bwd_op_name);
-  info_vector.emplace_back(std::move(op_meta));
-  info_ptr_ = &(info_vector.back());
-  return *this;
-}
-
 /////////////////////// Op register API /////////////////////////
 
 void RegisterAllCustomOperator() {
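
A brief sketch of the ordering contract the new constructor enforces (hypothetical op name `my_op` and `MyOp*` functions; not part of this diff):

// Slot 0 must be filled before slot 1: the constructor checks
// info_vector.size() == index_, so the forward op has to be registered first.
PD_BUILD_OP(my_op)          // index 0, stored as "my_op"
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(MyOpForward));   // MyOp* functions are assumed

PD_BUILD_GRAD_OP(my_op)     // index 1, name is suffixed to "my_op_grad"
    .Inputs({"X", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(MyOpBackward));
// Calling .SetInferShapeFn / .SetInferDtypeFn on this grad builder would raise
// errors::Unimplemented; X@GRAD reuses the forward X's shape and dtype instead.
// Swapping the two registrations would trip the PreconditionNotMet check above.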

paddle/fluid/framework/custom_operator.cc

Lines changed: 110 additions & 51 deletions
@@ -153,12 +153,21 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
   }
 
   VLOG(1) << "Run ComputeFunc.";
-  auto outs = func(custom_ins, custom_attrs);
+  try {
+    auto outs = func(custom_ins, custom_attrs);
 
-  VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    auto* true_out = ctx.Output<Tensor>(outputs[i]);
-    CustomTensorUtils::ShareDataTo(outs.at(i), true_out);
+    VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      auto* true_out = ctx.Output<Tensor>(outputs[i]);
+      CustomTensorUtils::ShareDataTo(outs.at(i), true_out);
+    }
+  } catch (platform::EnforceNotMet& exception) {
+    throw std::move(exception);
+  } catch (std::exception& ex) {
+    PADDLE_THROW(platform::errors::External("%s", ex.what()));
+  } catch (...) {
+    PADDLE_THROW(platform::errors::Fatal(
+        "Custom operator raises an unknown exception in rumtime."));
   }
 }
 
@@ -475,58 +484,108 @@ void RegisterOperatorWithMetaInfo(
       op_name, info.proto_->InitializationErrorString()));
 
   // InferShape
-  PADDLE_ENFORCE_NOT_NULL(
-      infer_shape_func,
-      platform::errors::PreconditionNotMet(
-          "InferShapeFn is nullptr. Need to set the InferShapeFn of custom "
-          "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
-  info.infer_shape_ = [op_inputs, op_outputs,
-                       infer_shape_func](InferShapeContext* ctx) {
-    std::vector<std::vector<int64_t>> input_shapes;
-
-    VLOG(1) << "Custom Operator: InferShape - get input ddim.";
-    for (auto& in_name : op_inputs) {
-      OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
-      auto ddim = ctx->GetInputDim(in_name);
-      input_shapes.emplace_back(framework::vectorize(ddim));
-    }
+  if (infer_shape_func == nullptr) {
+    // use default InferShape
+    info.infer_shape_ = [op_inputs, op_outputs](InferShapeContext* ctx) {
+      PADDLE_ENFORCE_EQ(
+          op_inputs.size(), 1UL,
+          platform::errors::Unavailable(
+              "Your custom operator contains multiple inputs. "
+              "We only allow a custom operator that contains only one input "
+              "and "
+              "only one output without setting the InferShapeFn. At this time, "
+              "the input shape will be directly set to the output shape.\n"
+              "Please set the InferShapeFn of custom "
+              "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
+      PADDLE_ENFORCE_EQ(
+          op_outputs.size(), 1UL,
+          platform::errors::Unavailable(
+              "Your custom operator contains multiple outputs. "
+              "We only allow a custom operator that contains only one input "
+              "and "
+              "only one output without setting the InferShapeFn. At this time, "
+              "the input shape will be directly set to the output shape.\n"
+              "Please set the InferShapeFn of custom "
+              "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
+
+      VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
+      ctx->ShareDim(op_inputs[0], op_outputs[0]);
+    };
+  } else {
+    info.infer_shape_ = [op_inputs, op_outputs,
+                         infer_shape_func](InferShapeContext* ctx) {
+      std::vector<std::vector<int64_t>> input_shapes;
+
+      VLOG(1) << "Custom Operator: InferShape - get input ddim.";
+      for (auto& in_name : op_inputs) {
+        OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
+        auto ddim = ctx->GetInputDim(in_name);
+        input_shapes.emplace_back(framework::vectorize(ddim));
+      }
 
-    VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
-    auto output_shapes = infer_shape_func(input_shapes);
+      VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
+      auto output_shapes = infer_shape_func(input_shapes);
 
-    VLOG(1) << "Custom Operator: InferShape - set output ddim.";
-    for (size_t i = 0; i < op_outputs.size(); ++i) {
-      ctx->SetOutputDim(op_outputs[i], framework::make_ddim(output_shapes[i]));
-    }
-  };
+      VLOG(1) << "Custom Operator: InferShape - set output ddim.";
+      for (size_t i = 0; i < op_outputs.size(); ++i) {
+        ctx->SetOutputDim(op_outputs[i],
+                          framework::make_ddim(output_shapes[i]));
+      }
+    };
+  }
 
   // Infer Dtype
-  PADDLE_ENFORCE_NOT_NULL(
-      infer_dtype_func,
-      platform::errors::PreconditionNotMet(
-          "InferDtypeFn is nullptr. Need to set the InferDtypeFn of custom "
-          "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
-  info.infer_var_type_ = [op_inputs, op_outputs,
-                          infer_dtype_func](InferVarTypeContext* ctx) {
-    std::vector<DataType> input_dtypes;
-
-    VLOG(1) << "Custom Operator: InferDtype - get input dtype.";
-    for (auto& in_name : op_inputs) {
-      auto dtype = ctx->GetInputDataType(in_name);
-      input_dtypes.emplace_back(
-          CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype));
-    }
+  if (infer_dtype_func == nullptr) {
+    // use defalut InferDtype
+    info.infer_var_type_ = [op_inputs, op_outputs](InferVarTypeContext* ctx) {
+      PADDLE_ENFORCE_EQ(
+          op_inputs.size(), 1UL,
+          platform::errors::Unavailable(
+              "Your custom operator contains multiple inputs. "
+              "We only allow a custom operator that contains only one input "
+              "and "
+              "only one output without setting the InferDtypeFn. At this time, "
+              "the input dtype will be directly set to the output dtype.\n"
+              "Please set the InferDtypeFn of custom "
+              "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
      PADDLE_ENFORCE_EQ(
+          op_outputs.size(), 1UL,
+          platform::errors::Unavailable(
+              "Your custom operator contains multiple outputs. "
+              "We only allow a custom operator that contains only one input "
+              "and "
+              "only one output without setting the InferDtypeFn. At this time, "
+              "the input dtype will be directly set to the output dtype.\n"
+              "Please set the InferDtypeFn of custom "
+              "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
+
+      VLOG(1) << "Custom Operator: InferDtype - share dtype.";
+      auto dtype = ctx->GetInputDataType(op_inputs[0]);
+      ctx->SetOutputDataType(op_outputs[0], dtype);
+    };
+  } else {
+    info.infer_var_type_ = [op_inputs, op_outputs,
+                            infer_dtype_func](InferVarTypeContext* ctx) {
+      std::vector<DataType> input_dtypes;
+
+      VLOG(1) << "Custom Operator: InferDtype - get input dtype.";
+      for (auto& in_name : op_inputs) {
+        auto dtype = ctx->GetInputDataType(in_name);
+        input_dtypes.emplace_back(
+            CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype));
+      }
 
-    VLOG(1) << "Custom Operator: InferDtype - infer output dtype.";
-    auto output_dtypes = infer_dtype_func(input_dtypes);
+      VLOG(1) << "Custom Operator: InferDtype - infer output dtype.";
+      auto output_dtypes = infer_dtype_func(input_dtypes);
 
-    VLOG(1) << "Custom Operator: InferDtype - set output dtype.";
-    for (size_t i = 0; i < op_outputs.size(); ++i) {
-      ctx->SetOutputDataType(
-          op_outputs[i],
-          CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i]));
-    }
-  };
+      VLOG(1) << "Custom Operator: InferDtype - set output dtype.";
+      for (size_t i = 0; i < op_outputs.size(); ++i) {
+        ctx->SetOutputDataType(
+            op_outputs[i],
+            CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i]));
+      }
+    };
+  }
 
   // Kernel func
   RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
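
With the fallback added above, a single-input, single-output custom op may omit both infer functions; a minimal sketch, assuming the same hypothetical `custom_relu` kernel as in the earlier sketch:

// No SetInferShapeFn / SetInferDtypeFn: registration installs the default
// lambdas, which require exactly one input and one output and then share the
// input's ddim and dtype to the output.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

// An op with several inputs registered this way would instead hit the
// errors::Unavailable check inside the default InferShape/InferDtype lambdas.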

python/paddle/fluid/tests/custom_op/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -3,10 +3,12 @@ if(WITH_GPU)
   # 'test_custom_relu_op_setup/jit' compile .cc and .cu file
   py_test(test_custom_relu_op_setup SRCS test_custom_relu_op_setup.py)
   py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py)
+  py_test(test_custom_relu_model SRCS test_custom_relu_model.py)
 
   # Compiling shared library will cost some time, but running process is very fast.
   set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
   set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
+  set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
 endif()
 
 py_test(test_sysconfig SRCS test_sysconfig.py)
