@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
105
105
framework::ProgramDesc program (program_desc);
106
106
framework::Executor executor (dev_place);
107
107
108
- // rpc_service_->Reset();
109
108
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
110
109
bool exit_flag = false ;
110
+ int64_t barrier_size = param_count * fan_in;
111
111
while (!exit_flag) {
112
112
// Get from multiple trainers, we don't care about the order in which
113
113
// the gradients arrives, just add suffix 0~n and merge the gradient.
114
- rpc_service_->SetCond (kCondStart );
115
- VLOG (3 ) << " ================ start get from service ===========" ;
116
- for (size_t i = 0 ; i < param_count * fan_in; ++i) {
114
+ rpc_service_->SetCond (0 );
115
+ for (size_t i = 0 ; i < barrier_size; ++i) {
117
116
const detail::MessageWithName &v = rpc_service_->Get ();
118
117
auto grad_var_name = v.first ;
119
118
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
130
129
}
131
130
VLOG (3 ) << " recved grad: " << grad_var_name
132
131
<< " updating param: " << param_var_name;
133
- // Assume grad_var_name must appear in global scope.
134
- std::string grad_var_name_trainer;
135
132
if (fan_in > 1 ) {
136
133
grad_var_name = this ->GetGradVarNameForTrainer (grad_var_name);
137
134
}
@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
145
142
if (exit_flag) {
146
143
break ;
147
144
}
148
- // rpc_service_->Reset();
149
145
try {
150
146
executor.Run (program, &recv_scope, 0 , /* global_block*/
151
147
false /* create_local_scope*/ , false /* create_vars*/ );
152
148
} catch (std::exception &e) {
153
149
LOG (ERROR) << " run sub program error " << e.what ();
154
150
}
155
- VLOG (3 ) << " ================ run sub program end ===========" ;
156
- rpc_service_->SetCond (kCondDone );
157
- rpc_service_->WaitClientGet (param_count * fan_in);
151
+ rpc_service_->SetCond (1 );
152
+ rpc_service_->WaitClientGet (barrier_size);
158
153
grads_counter_.clear ();
159
154
} // while(true)
160
155
}
0 commit comments