
Commit 731caea

[Cherry-pick] Fix the double grad bug for the star gan. (#25655) (#25964)

* Fix the double grad bug for the star gan. (#25655)
* Update the retain_graph parameter doc. test=develop
1 parent 2a7efef commit 731caea
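Taken together, the change threads a new `retain_graph` flag from the Python `backward()` API through the pybind binding down to `BasicEngine`, so the backward trace is only cleared when the caller no longer needs it. A minimal usage sketch against the Paddle 1.x dygraph API shown in this diff (the toy computation is illustrative, not part of the patch):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    x.stop_gradient = False
    loss = fluid.layers.reduce_sum(x * x)
    # Keep the backward graph alive so a second pass is possible.
    loss.backward(retain_graph=True)
    # This call uses the default retain_graph=False, after which
    # BasicEngine clears the backward trace as before.
    loss.backward()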

File tree

4 files changed, +18 -7 lines

paddle/fluid/imperative/basic_engine.cc (6 additions, 2 deletions)

@@ -33,8 +33,10 @@
 namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
 
@@ -224,7 +226,9 @@ void BasicEngine::Execute() {
     need_accu_var_list_.clear();
 
     VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-    cur_op.ClearBackwardTrace();
+    if (!retain_graph_) {
+      cur_op.ClearBackwardTrace();
+    }
   }
 
   // Step 3: Collect ready ops
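This guard is the heart of the fix: previously `cur_op.ClearBackwardTrace()` ran unconditionally after every op, so a second backward pass found an empty trace. A hedged Python-level illustration of the old failure mode (depending on the build, the second call may raise or silently produce no gradients; the message below is illustrative):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2], dtype='float32'))
    x.stop_gradient = False
    loss = fluid.layers.reduce_sum(x * x)
    loss.backward()  # trace cleared here: retain_graph defaults to False
    try:
        loss.backward()  # second pass has no trace to replay
    except Exception as exc:
        print("second backward failed:", exc)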

paddle/fluid/imperative/basic_engine.h (3 additions, 1 deletion)

@@ -30,7 +30,8 @@ class OpBase;
 
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);
 
   void Execute() override;
 
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  bool retain_graph_;
 };
 
 }  // namespace imperative

paddle/fluid/pybind/imperative.cc (2 additions, 2 deletions)

@@ -694,11 +694,11 @@ void BindImperative(py::module *m_ptr) {
       .def("_run_backward",
            [](imperative::VarBase &self,
               const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
+              const imperative::Tracer &tracer, bool retain_graph) {
             // TODO(jiabin): when we impl more backward execution we can
             // select them
             auto *engine = tracer.GetEngine();
-            engine->Init(&self, bckst);
+            engine->Init(&self, bckst, retain_graph);
             VLOG(3) << "Start backward";
             engine->Execute();
             VLOG(3) << "Finish backward";

python/paddle/fluid/dygraph/varbase_patch_methods.py (7 additions, 2 deletions)

@@ -73,7 +73,7 @@ def set_value(self, value):
                 framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
 
@@ -82,6 +82,10 @@ def backward(self, backward_strategy=None):
 
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
+                like to add more ops to the built graph after calling this method (`backward`), set the
+                parameter `retain_graph` to True; the graph will then be retained. Setting it to False is much
+                more memory-efficient. Defaults to False.
 
         Returns:
             NoneType: None
 
@@ -113,7 +117,8 @@ def backward(self, backward_strategy=None):
                 backward_strategy = BackwardStrategy()
                 backward_strategy.sort_sum_gradient = False
 
-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
