@@ -21,6 +21,7 @@ limitations under the License. */
21
21
#include " paddle/fluid/framework/lod_tensor.h"
22
22
#include " paddle/fluid/framework/operator.h"
23
23
#include " paddle/fluid/framework/shape_inference.h"
24
+ #include " paddle/fluid/framework/shape_runtime_infer.h"
24
25
#include " paddle/fluid/framework/var_type.h"
25
26
#include " paddle/fluid/platform/profiler.h"
26
27
@@ -458,187 +459,147 @@ bool OpSupportGPU(const std::string& op_type) {
458
459
return false ;
459
460
}
460
461
461
- class RuntimeInferShapeContext : public InferShapeContext {
462
- public:
463
- RuntimeInferShapeContext (const OperatorBase& op, const Scope& scope)
464
- : op_(op), scope_(scope) {}
465
-
466
- bool HasInput (const std::string& name) const override {
467
- if (!op_.HasInputs (name)) {
468
- return false ;
469
- }
470
- auto & ins = Inputs (name);
471
- size_t length = ins.size ();
472
- if (length == 0 ) {
473
- return false ;
474
- }
475
- PADDLE_ENFORCE_EQ (length, 1UL ,
476
- " Input %s should not have more than one inputs" , name);
477
- auto ipt = ins[0 ];
478
- auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
479
- return var != nullptr ;
462
+ bool RuntimeInferShapeContext::HasInput (const std::string& name) const {
463
+ if (!op_.HasInputs (name)) {
464
+ return false ;
480
465
}
481
-
482
- bool HasOutput (const std::string& name) const override {
483
- if (!op_.HasOutputs (name)) {
484
- return false ;
485
- }
486
- auto & outs = Outputs (name);
487
- size_t length = outs.size ();
488
- if (length == 0 ) {
489
- return false ;
490
- }
491
- PADDLE_ENFORCE_EQ (length, 1UL ,
492
- " Output %s should not have more than one inputs" , name);
493
- auto ipt = outs[0 ];
494
- auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
495
- return var != nullptr ;
466
+ auto & ins = Inputs (name);
467
+ size_t length = ins.size ();
468
+ if (length == 0 ) {
469
+ return false ;
496
470
}
471
+ PADDLE_ENFORCE_EQ (length, 1UL ,
472
+ " Input %s should not have more than one inputs" , name);
473
+ auto ipt = ins[0 ];
474
+ auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
475
+ return var != nullptr ;
476
+ }
497
477
498
- bool HasInputs (const std::string& name) const override {
499
- if (!op_.HasInputs (name)) {
500
- return false ;
501
- }
502
- auto inputs = op_.Inputs (name);
503
- if (inputs.empty ()) {
504
- return false ;
505
- }
506
- for (auto & input : inputs) {
507
- if (scope_.FindVar (input) == nullptr ) {
508
- return false ;
509
- }
510
- }
511
- return true ;
478
+ bool RuntimeInferShapeContext::HasOutput (const std::string& name) const {
479
+ if (!op_.HasOutputs (name)) {
480
+ return false ;
512
481
}
482
+ auto & outs = Outputs (name);
483
+ size_t length = outs.size ();
484
+ if (length == 0 ) {
485
+ return false ;
486
+ }
487
+ PADDLE_ENFORCE_EQ (length, 1UL ,
488
+ " Output %s should not have more than one inputs" , name);
489
+ auto ipt = outs[0 ];
490
+ auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
491
+ return var != nullptr ;
492
+ }
513
493
514
- bool HasOutputs (const std::string& name) const override {
515
- if (!op_.HasOutputs (name)) {
516
- return false ;
517
- }
518
- auto outputs = op_.Outputs (name);
519
- if (outputs.empty ()) {
494
+ bool RuntimeInferShapeContext::HasInputs (const std::string& name) const {
495
+ if (!op_.HasInputs (name)) {
496
+ return false ;
497
+ }
498
+ auto inputs = op_.Inputs (name);
499
+ if (inputs.empty ()) {
500
+ return false ;
501
+ }
502
+ for (auto & input : inputs) {
503
+ if (scope_.FindVar (input) == nullptr ) {
520
504
return false ;
521
505
}
522
- for (auto & output : outputs) {
523
- if (scope_.FindVar (output) == nullptr ) {
524
- return false ;
525
- }
526
- }
527
- return true ;
528
506
}
507
+ return true ;
508
+ }
529
509
530
- AttrReader Attrs () const override { return AttrReader (op_.Attrs ()); }
531
-
532
- const std::vector<std::string>& Inputs (
533
- const std::string& name) const override {
534
- return op_.Inputs (name);
510
+ bool RuntimeInferShapeContext::HasOutputs (const std::string& name) const {
511
+ if (!op_.HasOutputs (name)) {
512
+ return false ;
535
513
}
536
-
537
- const std::vector<std::string>& Outputs (
538
- const std::string& name) const override {
539
- return op_.Outputs (name);
514
+ auto outputs = op_.Outputs (name);
515
+ if (outputs.empty ()) {
516
+ return false ;
540
517
}
518
+ for (auto & output : outputs) {
519
+ if (scope_.FindVar (output) == nullptr ) {
520
+ return false ;
521
+ }
522
+ }
523
+ return true ;
524
+ }
541
525
542
- void ShareLoD (const std::string& in, const std::string& out, size_t i = 0 ,
543
- size_t j = 0 ) const override {
544
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
545
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
546
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
547
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
548
- if (!in_var->IsType <LoDTensor>()) return ;
549
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
550
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
551
- auto in_tensor = in_var->Get <LoDTensor>();
552
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
553
- out_tensor->set_lod (in_tensor.lod ());
526
+ void RuntimeInferShapeContext::ShareLoD (const std::string& in,
527
+ const std::string& out, size_t i,
528
+ size_t j) const {
529
+ PADDLE_ENFORCE_LT (i, Inputs (in).size ());
530
+ PADDLE_ENFORCE_LT (j, Outputs (out).size ());
531
+ Variable* in_var = scope_.FindVar (Inputs (in)[i]);
532
+ Variable* out_var = scope_.FindVar (Outputs (out)[j]);
533
+ if (!in_var->IsType <LoDTensor>()) return ;
534
+ PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
535
+ " The %d-th output of Output(%s) must be LoDTensor." , j, out);
536
+ auto in_tensor = in_var->Get <LoDTensor>();
537
+ auto * out_tensor = out_var->GetMutable <LoDTensor>();
538
+ out_tensor->set_lod (in_tensor.lod ());
554
539
555
540
// TODO(dzhwinter) : reuse ShareLoD in most operators.
556
541
// Need to call ShareLayout explicitly in sequence related ops.
557
542
// Shall we have a better method to shared info between in/out Tensor?
558
543
#ifdef PADDLE_WITH_MKLDNN
559
- // Fix me: ugly workaround below
560
- // Correct solution:
561
- // set_layout() should NOT be called here (i.e. ShareLoD). Instead,
562
- // layout of output tensor should be set "manually" in Compute()
563
- // of each OPKernel. The reason layout should NOT be shared between
564
- // input and output "automatically" (now by InferShape()->ShareLoD())
565
- // is that layout transform may occur after InferShape().
566
- // Workaround:
567
- // Skip set_layout() when input layout is kMKLDNN
568
- // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
569
- // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
570
- // in Compute()
571
- if (in_tensor.layout () != DataLayout::kMKLDNN )
544
+ // Fix me: ugly workaround below
545
+ // Correct solution:
546
+ // set_layout() should NOT be called here (i.e. ShareLoD). Instead,
547
+ // layout of output tensor should be set "manually" in Compute()
548
+ // of each OPKernel. The reason layout should NOT be shared between
549
+ // input and output "automatically" (now by InferShape()->ShareLoD())
550
+ // is that layout transform may occur after InferShape().
551
+ // Workaround:
552
+ // Skip set_layout() when input layout is kMKLDNN
553
+ // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
554
+ // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
555
+ // in Compute()
556
+ if (in_tensor.layout () != DataLayout::kMKLDNN )
572
557
#endif
573
- out_tensor->set_layout (in_tensor.layout ());
574
- }
575
-
576
- void ShareLayout (const std::string& in, const std::string& out, size_t i = 0 ,
577
- size_t j = 0 ) const {
578
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
579
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
580
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
581
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
582
- if (!in_var->IsType <LoDTensor>()) return ;
583
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
584
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
585
- auto in_tensor = in_var->Get <LoDTensor>();
586
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
587
558
out_tensor->set_layout (in_tensor.layout ());
588
- }
589
-
590
- bool IsRuntime () const override { return true ; }
591
-
592
- protected:
593
- DDim GetDim (const std::string& name) const override {
594
- Variable* var = scope_.FindVar (name);
595
- PADDLE_ENFORCE_NOT_NULL (var);
596
- if (var->IsType <LoDTensor>()) {
597
- return var->Get <LoDTensor>().dims ();
598
- } else if (var->IsType <SelectedRows>()) {
599
- return var->Get <SelectedRows>().GetCompleteDims ();
600
- } else {
601
- PADDLE_THROW (
602
- " Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
603
- " type_id is %s." ,
604
- name, var->Type ().name ());
605
- }
606
- }
607
-
608
- std::vector<DDim> GetRepeatedDims (const std::string& name) const override {
609
- PADDLE_THROW (" Only compile time support this method" );
610
- }
611
-
612
- void SetDim (const std::string& name, const DDim& dim) override {
613
- Variable* var = scope_.FindVar (name);
614
- if (var->IsType <LoDTensor>()) {
615
- var->GetMutable <LoDTensor>()->Resize (dim);
616
- } else if (var->IsType <SelectedRows>()) {
617
- var->GetMutable <SelectedRows>()->set_height (dim[0 ]);
618
- } else {
619
- PADDLE_THROW (" Variable %s type_id %s, expect LoDTensor/SelectedRows." ,
620
- name, var->Type ().name ());
621
- }
622
- }
623
-
624
- void SetRepeatedDims (const std::string& name,
625
- const std::vector<DDim>& dims) override {
626
- PADDLE_THROW (" Only compile time support this method" );
627
- }
559
+ }
628
560
629
- proto::VarType::Type GetVarType (const std::string& name) const override {
630
- auto * var = scope_.FindVar (name);
631
- return ToVarType (var->Type ());
561
+ void RuntimeInferShapeContext::ShareLayout (const std::string& in,
562
+ const std::string& out, size_t i,
563
+ size_t j) const {
564
+ PADDLE_ENFORCE_LT (i, Inputs (in).size ());
565
+ PADDLE_ENFORCE_LT (j, Outputs (out).size ());
566
+ Variable* in_var = scope_.FindVar (Inputs (in)[i]);
567
+ Variable* out_var = scope_.FindVar (Outputs (out)[j]);
568
+ if (!in_var->IsType <LoDTensor>()) return ;
569
+ PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
570
+ " The %d-th output of Output(%s) must be LoDTensor." , j, out);
571
+ auto in_tensor = in_var->Get <LoDTensor>();
572
+ auto * out_tensor = out_var->GetMutable <LoDTensor>();
573
+ out_tensor->set_layout (in_tensor.layout ());
574
+ }
575
+
576
+ DDim RuntimeInferShapeContext::GetDim (const std::string& name) const {
577
+ Variable* var = scope_.FindVar (name);
578
+ PADDLE_ENFORCE_NOT_NULL (var);
579
+ if (var->IsType <LoDTensor>()) {
580
+ return var->Get <LoDTensor>().dims ();
581
+ } else if (var->IsType <SelectedRows>()) {
582
+ return var->Get <SelectedRows>().GetCompleteDims ();
583
+ } else {
584
+ PADDLE_THROW (
585
+ " Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
586
+ " type_id is %s." ,
587
+ name, var->Type ().name ());
632
588
}
589
+ }
633
590
634
- InferShapeVarPtr GetVarPtr (const std::string& name) override {
635
- return scope_.FindVar (name);
591
+ void RuntimeInferShapeContext::SetDim (const std::string& name,
592
+ const DDim& dim) {
593
+ Variable* var = scope_.FindVar (name);
594
+ if (var->IsType <LoDTensor>()) {
595
+ var->GetMutable <LoDTensor>()->Resize (dim);
596
+ } else if (var->IsType <SelectedRows>()) {
597
+ var->GetMutable <SelectedRows>()->set_height (dim[0 ]);
598
+ } else {
599
+ PADDLE_THROW (" Variable %s type_id %s, expect LoDTensor/SelectedRows." , name,
600
+ var->Type ().name ());
636
601
}
637
-
638
- private:
639
- const OperatorBase& op_;
640
- const Scope& scope_;
641
- };
602
+ }
642
603
643
604
static void CheckTensorNANOrInf (const std::string& name,
644
605
const framework::Tensor& tensor) {
0 commit comments