@@ -75,8 +75,8 @@ class ListenAndServOp : public framework::OperatorBase {
75
75
server_thread_->join ();
76
76
}
77
77
78
- void Run (const framework::Scope &scope,
79
- const platform::Place &dev_place) const override {
78
+ void RunImpl (const framework::Scope &scope,
79
+ const platform::Place &dev_place) const override {
80
80
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
81
81
auto &dev_ctx = *pool.Get (dev_place);
82
82
framework::Scope &recv_scope = scope.NewScope ();
@@ -101,7 +101,6 @@ class ListenAndServOp : public framework::OperatorBase {
101
101
// the gradients arrives, just add suffix 0~n and merge the gradient.
102
102
rpc_service_->SetCond (0 );
103
103
size_t recv_var_cnt = 0 ;
104
- size_t update_param_cnt = 0 ;
105
104
int batch_barrier = 0 ;
106
105
while (batch_barrier != fan_in) {
107
106
const detail::MessageWithName &v = rpc_service_->Get ();
@@ -128,37 +127,33 @@ class ListenAndServOp : public framework::OperatorBase {
128
127
}
129
128
}
130
129
}
131
- VLOG (3 ) << " recv " << recv_var_cnt << " parmeters for one barrier." ;
132
130
if (exit_flag) {
133
131
rpc_service_->ShutDown ();
134
132
}
135
- VLOG (3 ) << " run optimize graph..." ;
136
133
try {
137
134
executor.Run (*program, &recv_scope, block->ID (), /* global_block*/
138
135
false /* create_local_scope*/ , false /* create_vars*/ );
139
136
} catch (std::exception &e) {
140
137
LOG (ERROR) << " run sub program error " << e.what ();
141
138
}
142
-
143
139
// Reset the received sparse variables, the sum operator would not
144
140
// sum the input sparse variables whose rows are empty at the next
145
141
// mini-batch.
146
- // TOOD (Yancey1989): move the reset action into an operator, we couldn't
142
+ // TODO (Yancey1989): move the reset action into an operator, we couldn't
147
143
// have any hidden logic in the operator.
148
144
for (auto &var : sparse_vars) {
149
145
var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
150
146
}
151
147
rpc_service_->SetCond (1 );
152
- rpc_service_-> WaitClientGet (update_param_cnt);
153
- grads_counter_. clear ( );
148
+ // FIXME(typhoonzero): use another condition to sync wait clients get.
149
+ rpc_service_-> WaitClientGet (ins. size () );
154
150
sparse_vars.clear ();
155
151
} // while(true)
156
152
}
157
153
158
154
protected:
159
155
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
160
156
std::shared_ptr<std::thread> server_thread_;
161
- mutable std::unordered_map<std::string, int > grads_counter_;
162
157
};
163
158
164
159
class ListenAndServOpMaker : public framework ::OpProtoAndCheckerMaker {
0 commit comments