@@ -174,7 +174,7 @@ void BasicEngine::PrepareGradAccumulators(
174
174
if (!var) continue ;
175
175
176
176
bool find_grad_node_of_var = false ;
177
- if (var-> HasGradNode ()) {
177
+ if (grad_pending_nodes. size ()) {
178
178
// Because Inplace op overwrites the grad_node of the input grad_var. So
179
179
// only the information of grad_pending_node can be used to find the
180
180
// grad_node of grad_var.
@@ -240,7 +240,7 @@ void BasicEngine::PrepareGradAccumulators(
240
240
}
241
241
}
242
242
243
- if (!var-> HasGradNode () || !find_grad_node_of_var) {
243
+ if (!grad_pending_nodes. size () || !find_grad_node_of_var) {
244
244
auto & accumulator = accumulators_[var.get ()];
245
245
if (!accumulator) {
246
246
if (FLAGS_sort_sum_gradient) {
@@ -438,15 +438,15 @@ void BasicEngine::Execute() {
438
438
continue ;
439
439
}
440
440
441
+ const auto & grad_pending_nodes = shared_cur_node->GradPendingNodes ();
441
442
std::unordered_map<VariableWrapper*,
442
443
std::unique_ptr<GradientAccumulator>>::iterator
443
444
iter;
444
445
bool flag_find_grad = false ;
445
- if (var-> HasGradNode ()) {
446
+ if (grad_pending_nodes. size ()) {
446
447
VLOG (10 ) << " Find gradient of var (" << var->Name ()
447
448
<< " ) with grad_node." ;
448
- for (auto & grad_pending_node :
449
- shared_cur_node->GradPendingNodes ()) {
449
+ for (auto & grad_pending_node : grad_pending_nodes) {
450
450
const auto & iter_grad_node =
451
451
accumulators_with_grad_node_.find (grad_pending_node);
452
452
if (iter_grad_node != accumulators_with_grad_node_.end ()) {
@@ -458,10 +458,11 @@ void BasicEngine::Execute() {
458
458
}
459
459
}
460
460
if (!flag_find_grad) {
461
- VLOG (6 ) << " Cannot find gradient of variable " << var->Name ();
461
+ VLOG (6 ) << " Cannot find gradient of variable " << var->Name ()
462
+ << " in accumulators_with_grad_node_" ;
462
463
}
463
464
}
464
- if (!var-> HasGradNode () || !flag_find_grad) {
465
+ if (!grad_pending_nodes. size () || !flag_find_grad) {
465
466
VLOG (10 ) << " Find gradient of var (" << var->Name ()
466
467
<< " ) with no grad_node." ;
467
468
iter = accumulators_.find (var.get ());
0 commit comments