Skip to content

Commit 9ac785b

Browse files
committed
check graph's validation
1 parent 50104f1 commit 9ac785b

File tree

4 files changed

+73
-2
lines changed

4 files changed

+73
-2
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
272272
* Only variables should be the leaves of graph.
273273
*/
274274
AddOutputToLeafOps(&result);
275-
276275
return std::unique_ptr<SSAGraph>(graph);
277276
}
278277

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
1514
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
15+
#include <utility>
1616

1717
namespace paddle {
1818
namespace framework {
@@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
8383
op->AddOutput(dummy_leaf);
8484
}
8585
}
86+
87+
std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck(
88+
const ProgramDesc &program) final {
89+
std::unique_ptr<SSAGraph> graph = Build(program);
90+
PADDLE_ENFORCE(IsValidGraph(graph.get()));
91+
return std::move(graph);
92+
}
93+
94+
bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const {
95+
std::unordered_map<OpHandleBase *, size_t> pending_ops;
96+
std::unordered_set<VarHandleBase *> pending_vars;
97+
std::unordered_set<VarHandleBase *> ready_vars;
98+
std::unordered_set<OpHandleBase *> ready_ops;
99+
100+
auto insert_pending_var = [&](VarHandleBase *var) {
101+
pending_vars.insert(var);
102+
if (var->generated_op_ == nullptr) {
103+
ready_vars.emplace(var);
104+
}
105+
};
106+
107+
for (auto &var_map : graph->vars_) {
108+
for (auto &name_pair : var_map) {
109+
for (auto &version_pair : name_pair.second) {
110+
insert_pending_var(version_pair.get());
111+
}
112+
}
113+
}
114+
115+
for (auto &var : graph->dep_vars_) {
116+
insert_pending_var(var.get());
117+
}
118+
119+
for (auto &op : graph->ops_) {
120+
if (op->Inputs().empty()) {
121+
ready_ops.insert(op.get());
122+
} else {
123+
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
124+
}
125+
}
126+
127+
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
128+
for (auto *op : set) {
129+
for (auto out : op->Outputs()) {
130+
ready_vars.emplace(out);
131+
}
132+
}
133+
set.clear();
134+
};
135+
136+
while (!pending_vars.empty()) {
137+
run_all_ops(ready_ops);
138+
if (ready_vars.empty()) {
139+
return false;
140+
}
141+
for (auto ready_var : ready_vars.) {
142+
pending_vars.erase(ready_var);
143+
for (auto *op : ready_var->pending_ops_) {
144+
auto &deps = --pending_ops[op];
145+
if (deps == 0) {
146+
ready_ops.insert(op);
147+
}
148+
}
149+
}
150+
ready_vars.clear();
151+
}
152+
return true;
153+
}
86154
} // namespace details
87155
} // namespace framework
88156
} // namespace paddle

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class SSAGraphBuilder {
3131
virtual ~SSAGraphBuilder() {}
3232
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
3333

34+
std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final;
35+
3436
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
3537

3638
protected:
@@ -48,6 +50,7 @@ class SSAGraphBuilder {
4850
const platform::Place &place,
4951
size_t place_offset);
5052

53+
bool IsValidGraph(const SSAGraph *graph) const;
5154
// Add an output variable (each_var_name, place, place_offset) to op_handle,
5255
// which belongs to graph
5356
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
185185
ready_vars->Push(var);
186186
}
187187
}
188+
188189
void ThreadedSSAGraphExecutor::RunOp(
189190
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
190191
auto op_run = [ready_var_q, op, this] {

0 commit comments

Comments
 (0)