@@ -544,11 +544,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
544
544
545
545
void ShareLoD (const std::string& in, const std::string& out, size_t i = 0 ,
546
546
size_t j = 0 ) const override {
547
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
548
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
549
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
550
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
547
+ const std::vector<std::string>& inputs = Inputs (in);
548
+ const std::vector<std::string>& outputs = Outputs (out);
549
+ PADDLE_ENFORCE_LT (i, inputs.size ());
550
+ PADDLE_ENFORCE_LT (j, outputs.size ());
551
+ Variable* in_var = scope_.FindVar (inputs.at (i));
551
552
if (!in_var->IsType <LoDTensor>()) return ;
553
+ Variable* out_var = scope_.FindVar (outputs.at (j));
552
554
PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
553
555
" The %d-th output of Output(%s) must be LoDTensor." , j, out);
554
556
auto in_tensor = in_var->Get <LoDTensor>();
@@ -576,20 +578,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
576
578
out_tensor->set_layout (in_tensor.layout ());
577
579
}
578
580
579
- void ShareLayout (const std::string& in, const std::string& out, size_t i = 0 ,
580
- size_t j = 0 ) const {
581
- PADDLE_ENFORCE_LT (i, Inputs (in).size ());
582
- PADDLE_ENFORCE_LT (j, Outputs (out).size ());
583
- Variable* in_var = scope_.FindVar (Inputs (in)[i]);
584
- Variable* out_var = scope_.FindVar (Outputs (out)[j]);
585
- if (!in_var->IsType <LoDTensor>()) return ;
586
- PADDLE_ENFORCE (out_var->IsType <LoDTensor>(),
587
- " The %d-th output of Output(%s) must be LoDTensor." , j, out);
588
- auto in_tensor = in_var->Get <LoDTensor>();
589
- auto * out_tensor = out_var->GetMutable <LoDTensor>();
590
- out_tensor->set_layout (in_tensor.layout ());
591
- }
592
-
593
581
bool IsRuntime () const override { return true ; }
594
582
595
583
protected:
0 commit comments