@@ -24,68 +24,6 @@ namespace paddle {
24
24
namespace framework {
25
25
namespace ir {
26
26
27
- std::vector<std::string> FindDistTrainSendVars (
28
- const std::vector<ir::Node *> &nodes) {
29
- std::vector<std::string> send_vars;
30
- // since parameters are all in block 0,
31
- // it's enough to only scan send ops in block 0
32
- for (auto &node : nodes) {
33
- auto op_vars = node->Op ()->InputArgumentNames ();
34
- send_vars.reserve (send_vars.size () +
35
- std::distance (op_vars.begin (), op_vars.end ()));
36
- send_vars.insert (send_vars.end (), op_vars.begin (), op_vars.end ());
37
- }
38
- return send_vars;
39
- }
40
-
41
- std::vector<std::string> FindDistTrainRecvVars (
42
- const std::vector<ir::Node *> &nodes) {
43
- std::vector<std::string> recv_vars;
44
- for (auto &node : nodes) {
45
- auto op_vars = node->Op ()->OutputArgumentNames ();
46
- recv_vars.reserve (recv_vars.size () +
47
- std::distance (op_vars.begin (), op_vars.end ()));
48
- recv_vars.insert (recv_vars.end (), op_vars.begin (), op_vars.end ());
49
- }
50
- return recv_vars;
51
- }
52
-
53
- bool IsDistTrainOp (ir::Node *node, const std::vector<std::string> &send_vars,
54
- const std::vector<std::string> &recv_vars) {
55
- if (send_vars.size () == 0 || recv_vars.size () == 0 ) {
56
- return false ;
57
- }
58
-
59
- /* *
60
- * Check any of opvars contains `.block` and in sendvars
61
- */
62
- auto checker = [](const std::vector<std::string> &opvars,
63
- const std::vector<std::string> &rpc_vars) -> bool {
64
- for (auto &var : opvars) {
65
- // a variable name with the suffix `.block` means it's a splited
66
- // variable by (DistributeTranspiler)
67
- // [python/paddle/fluid/transpiler/distribute_transpiler.py]
68
- if (var.find (" .block" ) != std::string::npos &&
69
- std::find (rpc_vars.begin (), rpc_vars.end (), var) != rpc_vars.end ()) {
70
- return true ;
71
- }
72
- }
73
- return false ;
74
- };
75
-
76
- std::vector<std::string> input_var_names;
77
- std::vector<std::string> output_var_names;
78
- for (ir::Node *input : node->inputs ) {
79
- input_var_names.push_back (input->Name ());
80
- }
81
- for (ir::Node *output : node->outputs ) {
82
- output_var_names.push_back (output->Name ());
83
- }
84
-
85
- return checker (output_var_names, send_vars) ||
86
- checker (input_var_names, recv_vars);
87
- }
88
-
89
27
Graph::Graph (const ProgramDesc &program) : program_(program) {
90
28
// Make the nodes id start from 0.
91
29
Node::ResetId ();
0 commit comments