@@ -153,12 +153,21 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
153
153
}
154
154
155
155
VLOG (1 ) << " Run ComputeFunc." ;
156
- auto outs = func (custom_ins, custom_attrs);
156
+ try {
157
+ auto outs = func (custom_ins, custom_attrs);
157
158
158
- VLOG (1 ) << " Custom Operator: Share outputs into ExecutionContext." ;
159
- for (size_t i = 0 ; i < outputs.size (); ++i) {
160
- auto * true_out = ctx.Output <Tensor>(outputs[i]);
161
- CustomTensorUtils::ShareDataTo (outs.at (i), true_out);
159
+ VLOG (1 ) << " Custom Operator: Share outputs into ExecutionContext." ;
160
+ for (size_t i = 0 ; i < outputs.size (); ++i) {
161
+ auto * true_out = ctx.Output <Tensor>(outputs[i]);
162
+ CustomTensorUtils::ShareDataTo (outs.at (i), true_out);
163
+ }
164
+ } catch (platform::EnforceNotMet& exception) {
165
+ throw std::move (exception);
166
+ } catch (std::exception& ex) {
167
+ PADDLE_THROW (platform::errors::External (" %s" , ex.what ()));
168
+ } catch (...) {
169
+ PADDLE_THROW (platform::errors::Fatal (
170
+ " Custom operator raises an unknown exception in rumtime." ));
162
171
}
163
172
}
164
173
@@ -475,58 +484,108 @@ void RegisterOperatorWithMetaInfo(
475
484
op_name, info.proto_ ->InitializationErrorString ()));
476
485
477
486
// InferShape
478
- PADDLE_ENFORCE_NOT_NULL (
479
- infer_shape_func,
480
- platform::errors::PreconditionNotMet (
481
- " InferShapeFn is nullptr. Need to set the InferShapeFn of custom "
482
- " operator by .SetInferShapeFn(PD_INFER_SHAPE(...))" ));
483
- info.infer_shape_ = [op_inputs, op_outputs,
484
- infer_shape_func](InferShapeContext* ctx) {
485
- std::vector<std::vector<int64_t >> input_shapes;
486
-
487
- VLOG (1 ) << " Custom Operator: InferShape - get input ddim." ;
488
- for (auto & in_name : op_inputs) {
489
- OP_INOUT_CHECK (ctx->HasInput (in_name), " Input" , in_name, " Custom" );
490
- auto ddim = ctx->GetInputDim (in_name);
491
- input_shapes.emplace_back (framework::vectorize (ddim));
492
- }
487
+ if (infer_shape_func == nullptr ) {
488
+ // use default InferShape
489
+ info.infer_shape_ = [op_inputs, op_outputs](InferShapeContext* ctx) {
490
+ PADDLE_ENFORCE_EQ (
491
+ op_inputs.size (), 1UL ,
492
+ platform::errors::Unavailable (
493
+ " Your custom operator contains multiple inputs. "
494
+ " We only allow a custom operator that contains only one input "
495
+ " and "
496
+ " only one output without setting the InferShapeFn. At this time, "
497
+ " the input shape will be directly set to the output shape.\n "
498
+ " Please set the InferShapeFn of custom "
499
+ " operator by .SetInferShapeFn(PD_INFER_SHAPE(...))" ));
500
+ PADDLE_ENFORCE_EQ (
501
+ op_outputs.size (), 1UL ,
502
+ platform::errors::Unavailable (
503
+ " Your custom operator contains multiple outputs. "
504
+ " We only allow a custom operator that contains only one input "
505
+ " and "
506
+ " only one output without setting the InferShapeFn. At this time, "
507
+ " the input shape will be directly set to the output shape.\n "
508
+ " Please set the InferShapeFn of custom "
509
+ " operator by .SetInferShapeFn(PD_INFER_SHAPE(...))" ));
510
+
511
+ VLOG (1 ) << " Custom Operator: Default InferShape - share ddim." ;
512
+ ctx->ShareDim (op_inputs[0 ], op_outputs[0 ]);
513
+ };
514
+ } else {
515
+ info.infer_shape_ = [op_inputs, op_outputs,
516
+ infer_shape_func](InferShapeContext* ctx) {
517
+ std::vector<std::vector<int64_t >> input_shapes;
518
+
519
+ VLOG (1 ) << " Custom Operator: InferShape - get input ddim." ;
520
+ for (auto & in_name : op_inputs) {
521
+ OP_INOUT_CHECK (ctx->HasInput (in_name), " Input" , in_name, " Custom" );
522
+ auto ddim = ctx->GetInputDim (in_name);
523
+ input_shapes.emplace_back (framework::vectorize (ddim));
524
+ }
493
525
494
- VLOG (1 ) << " Custom Operator: InferShape - calc output ddim." ;
495
- auto output_shapes = infer_shape_func (input_shapes);
526
+ VLOG (1 ) << " Custom Operator: InferShape - calc output ddim." ;
527
+ auto output_shapes = infer_shape_func (input_shapes);
496
528
497
- VLOG (1 ) << " Custom Operator: InferShape - set output ddim." ;
498
- for (size_t i = 0 ; i < op_outputs.size (); ++i) {
499
- ctx->SetOutputDim (op_outputs[i], framework::make_ddim (output_shapes[i]));
500
- }
501
- };
529
+ VLOG (1 ) << " Custom Operator: InferShape - set output ddim." ;
530
+ for (size_t i = 0 ; i < op_outputs.size (); ++i) {
531
+ ctx->SetOutputDim (op_outputs[i],
532
+ framework::make_ddim (output_shapes[i]));
533
+ }
534
+ };
535
+ }
502
536
503
537
// Infer Dtype
504
- PADDLE_ENFORCE_NOT_NULL (
505
- infer_dtype_func,
506
- platform::errors::PreconditionNotMet (
507
- " InferDtypeFn is nullptr. Need to set the InferDtypeFn of custom "
508
- " operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))" ));
509
- info.infer_var_type_ = [op_inputs, op_outputs,
510
- infer_dtype_func](InferVarTypeContext* ctx) {
511
- std::vector<DataType> input_dtypes;
512
-
513
- VLOG (1 ) << " Custom Operator: InferDtype - get input dtype." ;
514
- for (auto & in_name : op_inputs) {
515
- auto dtype = ctx->GetInputDataType (in_name);
516
- input_dtypes.emplace_back (
517
- CustomTensorUtils::ConvertInnerDTypeToEnumDType (dtype));
518
- }
538
+ if (infer_dtype_func == nullptr ) {
539
+ // use defalut InferDtype
540
+ info.infer_var_type_ = [op_inputs, op_outputs](InferVarTypeContext* ctx) {
541
+ PADDLE_ENFORCE_EQ (
542
+ op_inputs.size (), 1UL ,
543
+ platform::errors::Unavailable (
544
+ " Your custom operator contains multiple inputs. "
545
+ " We only allow a custom operator that contains only one input "
546
+ " and "
547
+ " only one output without setting the InferDtypeFn. At this time, "
548
+ " the input dtype will be directly set to the output dtype.\n "
549
+ " Please set the InferDtypeFn of custom "
550
+ " operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))" ));
551
+ PADDLE_ENFORCE_EQ (
552
+ op_outputs.size (), 1UL ,
553
+ platform::errors::Unavailable (
554
+ " Your custom operator contains multiple outputs. "
555
+ " We only allow a custom operator that contains only one input "
556
+ " and "
557
+ " only one output without setting the InferDtypeFn. At this time, "
558
+ " the input dtype will be directly set to the output dtype.\n "
559
+ " Please set the InferDtypeFn of custom "
560
+ " operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))" ));
561
+
562
+ VLOG (1 ) << " Custom Operator: InferDtype - share dtype." ;
563
+ auto dtype = ctx->GetInputDataType (op_inputs[0 ]);
564
+ ctx->SetOutputDataType (op_outputs[0 ], dtype);
565
+ };
566
+ } else {
567
+ info.infer_var_type_ = [op_inputs, op_outputs,
568
+ infer_dtype_func](InferVarTypeContext* ctx) {
569
+ std::vector<DataType> input_dtypes;
570
+
571
+ VLOG (1 ) << " Custom Operator: InferDtype - get input dtype." ;
572
+ for (auto & in_name : op_inputs) {
573
+ auto dtype = ctx->GetInputDataType (in_name);
574
+ input_dtypes.emplace_back (
575
+ CustomTensorUtils::ConvertInnerDTypeToEnumDType (dtype));
576
+ }
519
577
520
- VLOG (1 ) << " Custom Operator: InferDtype - infer output dtype." ;
521
- auto output_dtypes = infer_dtype_func (input_dtypes);
578
+ VLOG (1 ) << " Custom Operator: InferDtype - infer output dtype." ;
579
+ auto output_dtypes = infer_dtype_func (input_dtypes);
522
580
523
- VLOG (1 ) << " Custom Operator: InferDtype - set output dtype." ;
524
- for (size_t i = 0 ; i < op_outputs.size (); ++i) {
525
- ctx->SetOutputDataType (
526
- op_outputs[i],
527
- CustomTensorUtils::ConvertEnumDTypeToInnerDType (output_dtypes[i]));
528
- }
529
- };
581
+ VLOG (1 ) << " Custom Operator: InferDtype - set output dtype." ;
582
+ for (size_t i = 0 ; i < op_outputs.size (); ++i) {
583
+ ctx->SetOutputDataType (
584
+ op_outputs[i],
585
+ CustomTensorUtils::ConvertEnumDTypeToInnerDType (output_dtypes[i]));
586
+ }
587
+ };
588
+ }
530
589
531
590
// Kernel func
532
591
RegisterOperatorKernel (op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
0 commit comments