@@ -21,7 +21,6 @@ limitations under the License. */
21
21
#include " paddle/fluid/framework/lod_tensor.h"
22
22
#include " paddle/fluid/framework/operator.h"
23
23
#include " paddle/fluid/framework/shape_inference.h"
24
- #include " paddle/fluid/framework/shape_runtime_infer.h"
25
24
#include " paddle/fluid/framework/var_type.h"
26
25
#include " paddle/fluid/platform/profiler.h"
27
26
@@ -459,147 +458,184 @@ bool OpSupportGPU(const std::string& op_type) {
459
458
return false ;
460
459
}
461
460
462
- bool RuntimeInferShapeContext::HasInput (const std::string& name) const {
463
- if (!op_.HasInputs (name)) {
464
- return false ;
465
- }
466
- auto & ins = Inputs (name);
467
- size_t length = ins.size ();
468
- if (length == 0 ) {
469
- return false ;
470
- }
471
- PADDLE_ENFORCE_EQ (length, 1UL ,
472
- " Input %s should not have more than one inputs" , name);
473
- auto ipt = ins[0 ];
474
- auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
475
- return var != nullptr ;
476
- }
461
+ class RuntimeInferShapeContext : public InferShapeContext {
462
+ public:
463
+ RuntimeInferShapeContext (const OperatorBase& op, const Scope& scope)
464
+ : op_(op), scope_(scope) {}
477
465
478
- bool RuntimeInferShapeContext::HasOutput (const std::string& name) const {
479
- if (!op_.HasOutputs (name)) {
480
- return false ;
481
- }
482
- auto & outs = Outputs (name);
483
- size_t length = outs.size ();
484
- if (length == 0 ) {
485
- return false ;
486
- }
487
- PADDLE_ENFORCE_EQ (length, 1UL ,
488
- " Output %s should not have more than one inputs" , name);
489
- auto ipt = outs[0 ];
490
- auto * var = ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
491
- return var != nullptr ;
492
- }
466
+ bool HasInput (const std::string& name) const override {
467
+ // has only one input
468
+ const auto & ins = op_.Inputs ();
469
+ auto it = ins.find (name);
470
+ if (it == ins.end ()) {
471
+ return false ;
472
+ }
473
+ const auto & in = it->second ;
493
474
494
- bool RuntimeInferShapeContext::HasInputs (const std::string& name) const {
495
- if (!op_.HasInputs (name)) {
496
- return false ;
497
- }
498
- auto inputs = op_.Inputs (name);
499
- if (inputs.empty ()) {
500
- return false ;
501
- }
502
- for (auto & input : inputs) {
503
- if (scope_.FindVar (input) == nullptr ) {
475
+ if (in.size () != 1 || in[0 ] == kEmptyVarName ) {
504
476
return false ;
505
477
}
478
+ return scope_.FindVar (in[0 ]) != nullptr ;
506
479
}
507
- return true ;
508
- }
509
480
510
- bool RuntimeInferShapeContext::HasOutputs (const std::string& name) const {
511
- if (!op_.HasOutputs (name)) {
512
- return false ;
481
+ bool HasOutput (const std::string& name) const override {
482
+ // has only one output
483
+ const auto & outs = op_.Outputs ();
484
+ auto it = outs.find (name);
485
+ if (it == outs.end ()) {
486
+ return false ;
487
+ }
488
+ const auto & out = it->second ;
489
+ if (out.size () != 1 || out[0 ] == kEmptyVarName ) {
490
+ return false ;
491
+ }
492
+ return scope_.FindVar (out[0 ]) != nullptr ;
513
493
}
514
- auto outputs = op_.Outputs (name);
515
- if (outputs.empty ()) {
516
- return false ;
494
+
495
+ bool HasInputs (const std::string& name) const override {
496
+ if (!op_.HasInputs (name)) {
497
+ return false ;
498
+ }
499
+ auto inputs = op_.Inputs (name);
500
+ if (inputs.empty ()) {
501
+ return false ;
502
+ }
503
+ for (auto & input : inputs) {
504
+ if (scope_.FindVar (input) == nullptr ) {
505
+ return false ;
506
+ }
507
+ }
508
+ return true ;
517
509
}
518
- for (auto & output : outputs) {
519
- if (scope_.FindVar (output) == nullptr ) {
510
+
511
+ bool HasOutputs (const std::string& name) const override {
512
+ if (!op_.HasOutputs (name)) {
513
+ return false ;
514
+ }
515
+ auto outputs = op_.Outputs (name);
516
+ if (outputs.empty ()) {
520
517
return false ;
521
518
}
519
+ for (auto & output : outputs) {
520
+ if (scope_.FindVar (output) == nullptr ) {
521
+ return false ;
522
+ }
523
+ }
524
+ return true ;
522
525
}
523
- return true ;
524
- }
525
526
526
- void RuntimeInferShapeContext::ShareLoD (const std::string& in,
527
- const std::string& out, size_t i,
528
- size_t j) const {
529
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
530
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
531
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
532
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
533
- if (!in_var->IsType <LoDTensor>()) return ;
534
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
535
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
536
- auto in_tensor = in_var->Get <LoDTensor>();
537
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
538
- out_tensor->set_lod (in_tensor.lod ());
527
+ AttrReader Attrs () const override { return AttrReader (op_.Attrs ()); }
528
+
529
+ const std::vector<std::string>& Inputs (
530
+ const std::string& name) const override {
531
+ return op_.Inputs (name);
532
+ }
533
+
534
+ const std::vector<std::string>& Outputs (
535
+ const std::string& name) const override {
536
+ return op_.Outputs (name);
537
+ }
538
+
539
+ void ShareLoD (const std::string& in, const std::string& out, size_t i = 0 ,
540
+ size_t j = 0 ) const override {
541
+ PADDLE_ENFORCE_LT (i, Inputs (in).size ());
542
+ PADDLE_ENFORCE_LT (j, Outputs (out).size ());
543
+ Variable* in_var = scope_.FindVar (Inputs (in)[i]);
544
+ Variable* out_var = scope_.FindVar (Outputs (out)[j]);
545
+ if (!in_var->IsType <LoDTensor>()) return ;
546
+ PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
547
+ " The %d-th output of Output(%s) must be LoDTensor." , j, out);
548
+ auto in_tensor = in_var->Get <LoDTensor>();
549
+ auto * out_tensor = out_var->GetMutable <LoDTensor>();
550
+ out_tensor->set_lod (in_tensor.lod ());
539
551
540
552
// TODO(dzhwinter) : reuse ShareLoD in most operators.
541
553
// Need to call ShareLayout explicitly in sequence related ops.
542
554
// Shall we have a better method to shared info between in/out Tensor?
543
555
#ifdef PADDLE_WITH_MKLDNN
544
- // Fix me: ugly workaround below
545
- // Correct solution:
546
- // set_layout() should NOT be called here (i.e. ShareLoD). Instead,
547
- // layout of output tensor should be set "manually" in Compute()
548
- // of each OPKernel. The reason layout should NOT be shared between
549
- // input and output "automatically" (now by InferShape()->ShareLoD())
550
- // is that layout transform may occur after InferShape().
551
- // Workaround:
552
- // Skip set_layout() when input layout is kMKLDNN
553
- // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
554
- // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
555
- // in Compute()
556
- if (in_tensor.layout () != DataLayout::kMKLDNN )
556
+ // Fix me: ugly workaround below
557
+ // Correct solution:
558
+ // set_layout() should NOT be called here (i.e. ShareLoD). Instead,
559
+ // layout of output tensor should be set "manually" in Compute()
560
+ // of each OPKernel. The reason layout should NOT be shared between
561
+ // input and output "automatically" (now by InferShape()->ShareLoD())
562
+ // is that layout transform may occur after InferShape().
563
+ // Workaround:
564
+ // Skip set_layout() when input layout is kMKLDNN
565
+ // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
566
+ // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
567
+ // in Compute()
568
+ if (in_tensor.layout () != DataLayout::kMKLDNN )
557
569
#endif
570
+ out_tensor->set_layout (in_tensor.layout ());
571
+ }
572
+
573
+ void ShareLayout (const std::string& in, const std::string& out, size_t i = 0 ,
574
+ size_t j = 0 ) const {
575
+ PADDLE_ENFORCE_LT (i, Inputs (in).size ());
576
+ PADDLE_ENFORCE_LT (j, Outputs (out).size ());
577
+ Variable* in_var = scope_.FindVar (Inputs (in)[i]);
578
+ Variable* out_var = scope_.FindVar (Outputs (out)[j]);
579
+ if (!in_var->IsType <LoDTensor>()) return ;
580
+ PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
581
+ " The %d-th output of Output(%s) must be LoDTensor." , j, out);
582
+ auto in_tensor = in_var->Get <LoDTensor>();
583
+ auto * out_tensor = out_var->GetMutable <LoDTensor>();
558
584
out_tensor->set_layout (in_tensor.layout ());
559
- }
585
+ }
560
586
561
- void RuntimeInferShapeContext::ShareLayout (const std::string& in,
562
- const std::string& out, size_t i,
563
- size_t j) const {
564
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
565
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
566
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
567
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
568
- if (!in_var->IsType <LoDTensor>()) return ;
569
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
570
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
571
- auto in_tensor = in_var->Get <LoDTensor>();
572
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
573
- out_tensor->set_layout (in_tensor.layout ());
574
- }
575
-
576
- DDim RuntimeInferShapeContext::GetDim (const std::string& name) const {
577
- Variable* var = scope_.FindVar (name);
578
- PADDLE_ENFORCE_NOT_NULL (var);
579
- if (var->IsType <LoDTensor>()) {
580
- return var->Get <LoDTensor>().dims ();
581
- } else if (var->IsType <SelectedRows>()) {
582
- return var->Get <SelectedRows>().GetCompleteDims ();
583
- } else {
584
- PADDLE_THROW (
585
- " Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
586
- " type_id is %s." ,
587
- name, var->Type ().name ());
587
+ bool IsRuntime () const override { return true ; }
588
+
589
+ protected:
590
+ DDim GetDim (const std::string& name) const override {
591
+ Variable* var = scope_.FindVar (name);
592
+ PADDLE_ENFORCE_NOT_NULL (var);
593
+ if (var->IsType <LoDTensor>()) {
594
+ return var->Get <LoDTensor>().dims ();
595
+ } else if (var->IsType <SelectedRows>()) {
596
+ return var->Get <SelectedRows>().GetCompleteDims ();
597
+ } else {
598
+ PADDLE_THROW (
599
+ " Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
600
+ " type_id is %s." ,
601
+ name, var->Type ().name ());
602
+ }
588
603
}
589
- }
590
604
591
- void RuntimeInferShapeContext::SetDim (const std::string& name,
592
- const DDim& dim) {
593
- Variable* var = scope_.FindVar (name);
594
- if (var->IsType <LoDTensor>()) {
595
- var->GetMutable <LoDTensor>()->Resize (dim);
596
- } else if (var->IsType <SelectedRows>()) {
597
- var->GetMutable <SelectedRows>()->set_height (dim[0 ]);
598
- } else {
599
- PADDLE_THROW (" Variable %s type_id %s, expect LoDTensor/SelectedRows." , name,
600
- var->Type ().name ());
605
+ std::vector<DDim> GetRepeatedDims (const std::string& name) const override {
606
+ PADDLE_THROW (" Only compile time support this method" );
601
607
}
602
- }
608
+
609
+ void SetDim (const std::string& name, const DDim& dim) override {
610
+ Variable* var = scope_.FindVar (name);
611
+ if (var->IsType <LoDTensor>()) {
612
+ var->GetMutable <LoDTensor>()->Resize (dim);
613
+ } else if (var->IsType <SelectedRows>()) {
614
+ var->GetMutable <SelectedRows>()->set_height (dim[0 ]);
615
+ } else {
616
+ PADDLE_THROW (" Variable %s type_id %s, expect LoDTensor/SelectedRows." ,
617
+ name, var->Type ().name ());
618
+ }
619
+ }
620
+
621
+ void SetRepeatedDims (const std::string& name,
622
+ const std::vector<DDim>& dims) override {
623
+ PADDLE_THROW (" Only compile time support this method" );
624
+ }
625
+
626
+ proto::VarType::Type GetVarType (const std::string& name) const override {
627
+ auto * var = scope_.FindVar (name);
628
+ return ToVarType (var->Type ());
629
+ }
630
+
631
+ InferShapeVarPtr GetVarPtr (const std::string& name) override {
632
+ return scope_.FindVar (name);
633
+ }
634
+
635
+ private:
636
+ const OperatorBase& op_;
637
+ const Scope& scope_;
638
+ };
603
639
604
640
static void CheckTensorNANOrInf (const std::string& name,
605
641
const framework::Tensor& tensor) {
0 commit comments