Skip to content

Commit c91de28

Browse files
jacquesqiaoreyoung
authored andcommitted
CompileTime InferShape should find var recursively in stack of blocks (#4998)
* recursive find var in BlockDesc * add HasVarRecursive and FindVarRecursive to BlockDesc * fix FindVarRecursive
1 parent 54ffafa commit c91de28

File tree

6 files changed

+54
-9
lines changed

6 files changed

+54
-9
lines changed

paddle/framework/block_desc.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ bool BlockDescBind::HasVar(const std::string &name) const {
4141
return vars_.find(name) != vars_.end();
4242
}
4343

44+
VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
45+
auto it = vars_.find(name);
46+
if (it == vars_.end()) {
47+
return Parent() == kNoneBlockIndex ? nullptr
48+
: ParentBlock()->FindVarRecursive(name);
49+
}
50+
return it->second.get();
51+
}
52+
53+
bool BlockDescBind::HasVarRecursive(const std::string &name) const {
54+
return FindVarRecursive(name) != nullptr;
55+
}
56+
4457
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
4558
std::vector<VarDescBind *> res;
4659
for (const auto &p : vars_) {
@@ -97,7 +110,7 @@ void BlockDescBind::Flush() {
97110
}
98111

99112
BlockDescBind *BlockDescBind::ParentBlock() const {
100-
if (this->desc_->parent_idx() == -1) {
113+
if (this->desc_->parent_idx() == kNoneBlockIndex) {
101114
return nullptr;
102115
}
103116
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));

paddle/framework/block_desc.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include <vector>
2222

2323
#include "paddle/framework/op_desc.h"
24+
#include "paddle/framework/proto_desc.h"
2425
#include "paddle/framework/var_desc.h"
2526
#include "paddle/platform/macros.h"
2627

@@ -56,6 +57,10 @@ class BlockDescBind {
5657

5758
bool HasVar(const std::string &var_name) const;
5859

60+
VarDescBind *FindVarRecursive(const std::string &name_bytes) const;
61+
62+
bool HasVarRecursive(const std::string &var_name) const;
63+
5964
std::set<std::string> LocalVarNames() const {
6065
std::set<std::string> var_names;
6166
for (auto &var : vars_) {

paddle/framework/operator.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
334334
"Input(%s) should have only one value, "
335335
"but it have %d now",
336336
name, length);
337-
return block_.HasVar(input_names[0]);
337+
return block_.HasVarRecursive(input_names[0]);
338338
}
339339

340340
bool HasOutput(const std::string& name) const override {
@@ -347,7 +347,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
347347
"Output(%s) should have only one value, "
348348
"but it have %d now",
349349
name, length);
350-
return block_.HasVar(output_names[0]);
350+
return block_.HasVarRecursive(output_names[0]);
351351
}
352352

353353
bool HasInputs(const std::string& name) const override {
@@ -356,7 +356,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
356356
return false;
357357
}
358358
for (auto& input : input_names) {
359-
if (!block_.HasVar(input)) return false;
359+
if (!block_.HasVarRecursive(input)) return false;
360360
}
361361
return true;
362362
}
@@ -367,7 +367,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
367367
return false;
368368
}
369369
for (auto& output : output_names) {
370-
if (!block_.HasVar(output)) return false;
370+
if (!block_.HasVarRecursive(output)) return false;
371371
}
372372
return true;
373373
}
@@ -414,11 +414,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
414414

415415
private:
416416
DDim GetDim(const std::string& name) const override {
417-
return framework::make_ddim(block_.FindVar(name)->Shape());
417+
return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
418418
}
419419

420420
void SetDim(const std::string& name, const DDim& dim) override {
421-
block_.FindVar(name)->SetShape(framework::vectorize(dim));
421+
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
422422
}
423423

424424
const OpDescBind& op_;

paddle/framework/program_desc.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ ProgramDesc *ProgramDescBind::Proto() {
3535

3636
ProgramDescBind::ProgramDescBind() {
3737
auto *block = prog_.mutable_blocks()->Add();
38-
block->set_idx(0);
39-
block->set_parent_idx(-1);
38+
block->set_idx(kRootBlockIndex);
39+
block->set_parent_idx(kNoneBlockIndex);
4040
blocks_.emplace_back(new BlockDescBind(this, block));
4141
}
4242

paddle/framework/program_desc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <memory>
1818
#include <vector>
1919
#include "paddle/framework/framework.pb.h"
20+
#include "paddle/framework/proto_desc.h"
2021
#include "paddle/platform/macros.h"
2122

2223
namespace paddle {

paddle/framework/proto_desc.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
namespace paddle {
18+
namespace framework {
19+
20+
// The Index of first Block in Program. also called root block.
21+
constexpr int kRootBlockIndex = 0;
22+
// The Parent Index of root Block, this block does not exist.
23+
constexpr int kNoneBlockIndex = -1;
24+
25+
} // namespace framework
26+
} // namespace paddle

0 commit comments

Comments
 (0)