@@ -12,185 +12,145 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <stdint.h>
 #include <ostream>
+#include <thread>

-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"

 namespace paddle {
 namespace operators {

-constexpr char kOptimizeBlock[] = "OptimizeBlock";
-
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
 }

-static void CreateTensorFromMessageType(framework::Variable *var,
-                                        sendrecv::VarType var_type) {
-  if (var_type == sendrecv::VarType::LOD_TENSOR) {
-    var->GetMutable<framework::LoDTensor>();
-  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
-    var->GetMutable<framework::SelectedRows>();
-  } else {
-    PADDLE_THROW(
-        "VariableMessage type %d is not in "
-        "[LoDTensor, SelectedRows]",
-        var_type);
-  }
+ListenAndServOp::ListenAndServOp(const std::string &type,
+                                 const framework::VariableNameMap &inputs,
+                                 const framework::VariableNameMap &outputs,
+                                 const framework::AttributeMap &attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
+
+int ListenAndServOp::GetSelectedPort() {
+  return rpc_service_->GetSelectedPort();
 }

-static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
-                                  framework::Executor *executor,
-                                  framework::ProgramDesc *program,
-                                  framework::Scope *scope) {
-  std::vector<std::future<void>> fs;
-  for (size_t idx : parallel_blkids) {
-    fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
-      int run_block = idx;  // thread local
-      try {
-        executor->Run(*program, scope, run_block, false, false);
-      } catch (std::exception &e) {
-        LOG(ERROR) << "run sub program error " << e.what();
-      }
-    }));
-  }
-  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+void ListenAndServOp::Stop() {
+  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+  server_thread_->join();
 }

-class ListenAndServOp : public framework::OperatorBase {
- public:
-  ListenAndServOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {
-    if (!rpc_service_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
-      server_thread_.reset(new std::thread(RunServer, rpc_service_));
-    }
-  }
+void ListenAndServOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(dev_place);
+  framework::Scope &recv_scope = scope.NewScope();
+  LOG(INFO) << "created recv scope: " << &recv_scope;

-  void Stop() override {
-    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
-    server_thread_->join();
+  if (!rpc_service_) {
+    std::string endpoint = Attr<std::string>("endpoint");
+    rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
   }

-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(dev_place);
-    framework::Scope &recv_scope = scope.NewScope();
-
-    // FIXME(Yancey1989): initialize rpc server with lazy mode.
-    rpc_service_->SetScope(&recv_scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    auto ins = Inputs("X");
-    auto fan_in = Attr<int>("Fanin");
-
-    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-    auto *program = block->Program();
-    int num_blocks = program->Size();
-    PADDLE_ENFORCE_GE(num_blocks, 2,
-                      "server program should have at least 2 blocks");
-
-    framework::Executor executor(dev_place);
-
-    // TODO(qiao) set proper fields for table lookup and update
-    rpc_service_->SetExecutor(&executor);
-    rpc_service_->SetPrefetchBlkdId(0);
-    rpc_service_->SetProgram(program);
-
-    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
-    bool exit_flag = false;
-    // Record received sparse variables, so that
-    // we could reset those after executing the optimize program
-    std::vector<framework::Variable *> sparse_vars;
-    while (!exit_flag) {
-      // Get from multiple trainers; we don't care about the order in which
-      // the gradients arrive, just add suffix 0~n and merge the gradient.
-      rpc_service_->SetCond(0);
-      size_t recv_var_cnt = 0;
-      int batch_barrier = 0;
-      while (batch_barrier != fan_in) {
-        const detail::ReceivedMessage v = rpc_service_->Get();
-        auto recv_var_name = v.first;
-        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-          LOG(INFO) << "received terminate message and exit";
-          exit_flag = true;
-          break;
-        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
-          VLOG(3) << "recv batch barrier message";
-          batch_barrier++;
-          continue;
-        } else {
-          VLOG(3) << "received grad: " << recv_var_name;
-          recv_var_cnt++;
-          auto var = v.second->GetVar();
-          if (var == nullptr) {
-            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-            PADDLE_THROW("Can not find server side var");
-          }
-          if (var->IsType<framework::SelectedRows>()) {
-            sparse_vars.push_back(var);
-          }
-        }
-      }
-      if (exit_flag) {
-        rpc_service_->SetCond(1);
-        rpc_service_->ShutDown();
+  auto ins = Inputs("X");
+  auto fan_in = Attr<int>("Fanin");
+  auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *program = block->Program();
+  size_t num_blocks = program->Size();
+  PADDLE_ENFORCE_GE(num_blocks, 2,
+                    "server program should have at least 2 blocks");
+
+  framework::Executor executor(dev_place);
+
+  // FIXME(Yancey1989): initialize rpc server with lazy mode.
+  rpc_service_->SetScope(&recv_scope);
+  rpc_service_->SetDevCtx(&dev_ctx);
+  // TODO(qiao) set proper fields for table lookup and update
+  rpc_service_->SetExecutor(&executor);
+  rpc_service_->SetPrefetchBlkdId(0);
+  rpc_service_->SetProgram(program);
+  // start the server listening after all members are initialized.
+  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  // FIXME(typhoonzero): do we need to wait until the server port is ready?
+  sleep(5);
+
+  // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+  bool exit_flag = false;
+  // Record received sparse variables, so that
+  // we could reset those after executing the optimize program
+  std::vector<framework::Variable *> sparse_vars;
+  while (!exit_flag) {
+    // Get from multiple trainers; we don't care about the order in which
+    // the gradients arrive, just add suffix 0~n and merge the gradient.
+    rpc_service_->SetCond(0);
+    size_t recv_var_cnt = 0;
+    int batch_barrier = 0;
+    while (batch_barrier != fan_in) {
+      const detail::ReceivedMessage v = rpc_service_->Get();
+      auto recv_var_name = v.first;
+      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
+        LOG(INFO) << "received terminate message and exit";
+        exit_flag = true;
         break;
-      }
-
-      // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
-      // and this will still work.
-
-      // The optimize blocks which have the same parent ID would run in parallel
-      // TODO(Yancey1989): need to use ParallelExecutor in the future
-      size_t last_parent_blkid = program->Block(1).Parent();
-      std::vector<size_t> parallel_blkids;
-      parallel_blkids.push_back(1);
-      double ts = detail::GetTimestamp();
-      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          for (size_t idx : parallel_blkids) VLOG(3) << idx;
-          ParallelExecuteBlocks(parallel_blkids, &executor, program,
-                                &recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
+      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
+        VLOG(3) << "recv batch barrier message";
+        batch_barrier++;
+        continue;
+      } else {
+        VLOG(3) << "received grad: " << recv_var_name;
+        recv_var_cnt++;
+        auto var = v.second->GetVar();
+        if (var == nullptr) {
+          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
+          PADDLE_THROW("Can not find server side var");
+        }
+        if (var->IsType<framework::SelectedRows>()) {
+          sparse_vars.push_back(var);
         }
-        parallel_blkids.push_back(blkid);
-      }
-      ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
-
-      VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
-              << "(ms)";
-
-      // Reset the received sparse variables, the sum operator would not
-      // sum the input sparse variables whose rows are empty in the next
-      // mini-batch.
-      // TODO(Yancey1989): move the reset action into an operator; we shouldn't
-      // have any hidden logic in the operator.
-      for (auto &var : sparse_vars) {
-        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
       }
+    }
+    if (exit_flag) {
       rpc_service_->SetCond(1);
-      // FIXME(typhoonzero): use another condition to sync wait clients get.
-      rpc_service_->WaitClientGet(fan_in);
-      sparse_vars.clear();
-    }  // while(true)
-  }
+      rpc_service_->ShutDown();
+      break;
+    }

- protected:
-  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
-  std::shared_ptr<std::thread> server_thread_;
-};
+    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
+    // and this will still work.
+
+    // The optimize blocks which have the same parent ID would run in parallel
+    // TODO(Yancey1989): need to use ParallelExecutor in the future
+    int32_t last_parent_blkid = program->Block(1).Parent();
+    std::vector<size_t> parallel_blkids;
+    parallel_blkids.push_back(1);
+    double ts = detail::GetTimestamp();
+    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
+      if (program->Block(blkid).Parent() != last_parent_blkid) {
+        for (size_t idx : parallel_blkids) VLOG(3) << idx;
+        ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+        parallel_blkids.clear();
+        last_parent_blkid = program->Block(blkid).Parent();
+      }
+      parallel_blkids.push_back(blkid);
+    }
+    ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+
+    VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
+
+    // Reset the received sparse variables, the sum operator would not
+    // sum the input sparse variables whose rows are empty in the next
+    // mini-batch.
+    // TODO(Yancey1989): move the reset action into an operator; we shouldn't
+    // have any hidden logic in the operator.
+    for (auto &var : sparse_vars) {
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    }
+    rpc_service_->SetCond(1);
+    // FIXME(typhoonzero): use another condition to sync wait clients get.
+    rpc_service_->WaitClientGet(fan_in);
+    sparse_vars.clear();
+  }  // while(true)
+}

 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
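A note on the `sleep(5)` FIXME above: `RunImpl` only needs to block until the server thread has actually bound its listening port, not for a fixed five seconds. Below is a minimal, self-contained sketch of one way to do that with a `std::condition_variable` handshake. `FakeServer` and its `on_bound` callback are hypothetical stand-ins, not the real `detail::AsyncGRPCServer` API.

```cpp
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Hypothetical stand-in for detail::AsyncGRPCServer; the real interface
// may differ. The point is the readiness handshake, not the server itself.
class FakeServer {
 public:
  // Runs the serve loop; invokes on_bound once the listening port is bound.
  void Run(const std::function<void()> &on_bound) {
    selected_port_ = 12345;  // pretend the OS just assigned us a port
    on_bound();
    // ... would block here serving RPCs until shutdown ...
  }
  int selected_port() const { return selected_port_; }

 private:
  int selected_port_ = 0;
};

int main() {
  FakeServer server;
  std::mutex mu;
  std::condition_variable cv;
  bool ready = false;

  // Launch the server on its own thread, as RunImpl does with RunServer.
  std::thread server_thread([&] {
    server.Run([&] {
      std::lock_guard<std::mutex> lock(mu);
      ready = true;
      cv.notify_one();  // tell the launcher the port is bound
    });
  });

  // Instead of sleep(5): block exactly until the bind has been signalled.
  {
    std::unique_lock<std::mutex> lock(mu);
    cv.wait(lock, [&] { return ready; });
  }
  std::cout << "server ready on port " << server.selected_port() << "\n";

  server_thread.join();
  return 0;
}
```

With a handshake like this, the launcher waits exactly as long as the bind takes, and an accessor like the new `GetSelectedPort()` can safely report the bound port afterwards.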