Commit 9b9a225

Cherry-pick PR 38406, fix accumulator bug when multiple inplace OPs are executed continuously (#38406) (#38830)
Cherry pick PR #38406
1 parent 27774ee commit 9b9a225
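
For context, a minimal Python sketch of the pattern this commit fixes (it mirrors the new TestContinuouslyInplace test added below): several in-place OPs run back to back on the same Tensor, followed by backward(). This is a sketch of the triggering scenario, not part of the commit itself.

    # Minimal sketch of the triggering pattern (mirrors the new unit test below).
    import paddle

    a = paddle.rand([2, 3])
    a.stop_gradient = False   # make `a` a leaf that requires gradients
    b = a * 2

    # Each reshape_ call is an in-place OP that overwrites b's grad_node
    # (see the comment in basic_engine.cc below), so chaining several of them
    # exercises the gradient-accumulator lookup fixed by this commit.
    b.reshape_([-1])
    b.reshape_([2, 3])
    b.reshape_([-1])

    b.backward()              # should finish and accumulate the gradient into a.grad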

2 files changed: 21 additions & 7 deletions

paddle/fluid/imperative/basic_engine.cc

Lines changed: 8 additions & 7 deletions
@@ -174,7 +174,7 @@ void BasicEngine::PrepareGradAccumulators(
       if (!var) continue;

       bool find_grad_node_of_var = false;
-      if (var->HasGradNode()) {
+      if (grad_pending_nodes.size()) {
         // Because Inplace op overwrites the grad_node of the input grad_var. So
         // only the information of grad_pending_node can be used to find the
         // grad_node of grad_var.
@@ -240,7 +240,7 @@ void BasicEngine::PrepareGradAccumulators(
         }
       }

-      if (!var->HasGradNode() || !find_grad_node_of_var) {
+      if (!grad_pending_nodes.size() || !find_grad_node_of_var) {
         auto& accumulator = accumulators_[var.get()];
         if (!accumulator) {
           if (FLAGS_sort_sum_gradient) {
@@ -438,15 +438,15 @@ void BasicEngine::Execute() {
             continue;
           }

+          const auto& grad_pending_nodes = shared_cur_node->GradPendingNodes();
           std::unordered_map<VariableWrapper*,
                              std::unique_ptr<GradientAccumulator>>::iterator
               iter;
           bool flag_find_grad = false;
-          if (var->HasGradNode()) {
+          if (grad_pending_nodes.size()) {
             VLOG(10) << "Find gradient of var (" << var->Name()
                      << ") with grad_node.";
-            for (auto& grad_pending_node :
-                 shared_cur_node->GradPendingNodes()) {
+            for (auto& grad_pending_node : grad_pending_nodes) {
              const auto& iter_grad_node =
                  accumulators_with_grad_node_.find(grad_pending_node);
              if (iter_grad_node != accumulators_with_grad_node_.end()) {
@@ -458,10 +458,11 @@
              }
            }
            if (!flag_find_grad) {
-              VLOG(6) << "Cannot find gradient of variable " << var->Name();
+              VLOG(6) << "Cannot find gradient of variable " << var->Name()
+                      << " in accumulators_with_grad_node_";
            }
          }
-          if (!var->HasGradNode() || !flag_find_grad) {
+          if (!grad_pending_nodes.size() || !flag_find_grad) {
            VLOG(10) << "Find gradient of var (" << var->Name()
                     << ") with no grad_node.";
            iter = accumulators_.find(var.get());
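
As a Python-level sanity check (not part of this commit, only a hedged sketch): with the fix in place, the gradient produced through the in-place chain should match the gradient of the same computation without in-place OPs. The sketch assumes that Tensor.grad can be read back via .numpy() in this Paddle version.

    # Hedged sketch: compare a.grad with and without the in-place reshape chain.
    import numpy as np
    import paddle

    def grad_of_double(use_inplace_chain):
        a = paddle.ones([2, 3])
        a.stop_gradient = False
        b = a * 2
        if use_inplace_chain:
            b.reshape_([-1])
            b.reshape_([2, 3])
            b.reshape_([-1])
        b.backward()
        # Assumption: a.grad is a Tensor here; some releases expose it as an ndarray instead.
        return a.grad.numpy()

    print(np.array_equal(grad_of_double(True), grad_of_double(False)))  # expected: True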

python/paddle/fluid/tests/unittests/test_inplace.py

Lines changed: 13 additions & 0 deletions
@@ -434,5 +434,18 @@ def test_loss_is_inplace_var(self):
         self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))


+class TestContinuouslyInplace(unittest.TestCase):
+    def test_continuously_inplace(self):
+        a = paddle.rand([2, 3])
+        a.stop_gradient = False
+        b = a * 2
+
+        b.reshape_([-1])
+        b.reshape_([2, 3])
+        b.reshape_([-1])
+
+        b.backward()
+
+
 if __name__ == '__main__':
     unittest.main()
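
To run only the new test case locally, a hedged sketch (it assumes the working directory is python/paddle/fluid/tests/unittests so that test_inplace is importable):

    # Load and run only the TestContinuouslyInplace case added by this commit.
    import unittest
    import test_inplace

    suite = unittest.TestLoader().loadTestsFromTestCase(test_inplace.TestContinuouslyInplace)
    unittest.TextTestRunner(verbosity=2).run(suite)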
