Skip to content

Commit 67e3629

Browse files
authored
Cherry pick dygraph double grad depend bug (#25828)
* Fix dygraph grad bugs (#25781) * fix double grad visitid unit; test=develop * change name hash_pair to HashPair; test=develop * follow comment; test=develop * remove manual seed; test=develop * change create_graph from True to False; test=develop
1 parent 731caea commit 67e3629

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

paddle/fluid/imperative/partial_grad_engine.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
namespace paddle {
3737
namespace imperative {
3838

39+
struct HashPair {
40+
template <class T1, class T2>
41+
size_t operator()(const std::pair<T1, T2> &p) const noexcept {
42+
auto hash1 = std::hash<T1>{}(p.first);
43+
auto hash2 = std::hash<T2>{}(p.second);
44+
return hash1 ^ hash2;
45+
}
46+
};
47+
3948
/**
4049
* This function prunes the graph to get the ops between `output_targets`
4150
* and `input_target_grads`.
@@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets(
152161
target_vars = *input_target_grads;
153162

154163
std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
164+
std::unordered_set<std::pair<OpBase *, OpBase *>, HashPair> op_base_visited;
155165
for (auto &endpoint_op : endpoint_ops) {
156166
op_queue.emplace(endpoint_op, nullptr);
167+
op_base_visited.emplace(endpoint_op, nullptr);
157168
}
158169

159170
while (!op_queue.empty()) {
@@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets(
207218
if (pending_op) {
208219
VLOG(10) << "Pending op of " << op->Type() << " is "
209220
<< pending_op->Type();
221+
210222
pending_ops[op].insert(pending_op);
211223
++op_deps[pending_op];
212224
} else {
@@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets(
216228
auto iter = preceding_ops.find(op);
217229
if (iter != preceding_ops.end()) {
218230
for (auto &preceding_op : iter->second) {
219-
op_queue.emplace(preceding_op, op);
231+
if (op_base_visited.count(std::make_pair(preceding_op, op)) == 0) {
232+
op_queue.emplace(preceding_op, op);
233+
op_base_visited.emplace(preceding_op, op);
234+
}
220235
}
221236
}
222237
}

python/paddle/fluid/tests/unittests/test_imperative_double_grad.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import paddle.fluid as fluid
16+
import paddle
1617
from paddle.fluid.wrapped_decorator import wrap_decorator
1718
import unittest
1819
from unittest import TestCase
@@ -295,5 +296,48 @@ def setUp(self):
295296
self.shape = [5, 10]
296297

297298

299+
class TestDygraphDoubleGradVisitedUniq(TestCase):
300+
def test_compare(self):
301+
value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2,
302+
5).astype("float32")
303+
304+
def model_f(input):
305+
linear = fluid.dygraph.Linear(5, 3, bias_attr=False)
306+
for i in range(10):
307+
if i == 0:
308+
out = linear(input)
309+
else:
310+
out = out + linear(input)
311+
return out
312+
313+
backward_strategy = fluid.dygraph.BackwardStrategy()
314+
backward_strategy.sort_sum_gradient = True
315+
with fluid.dygraph.guard():
316+
fluid.default_startup_program().random_seed = 123
317+
fluid.default_main_program().random_seed = 123
318+
a = fluid.dygraph.to_variable(value)
319+
a.stop_gradient = False
320+
321+
out = model_f(a)
322+
323+
dx=fluid.dygraph.grad(outputs=[out],inputs=[a],create_graph=False,retain_graph=False, \
324+
only_inputs=True,allow_unused=False, backward_strategy=backward_strategy)
325+
326+
grad_1 = dx[0].numpy()
327+
328+
with fluid.dygraph.guard():
329+
fluid.default_startup_program().random_seed = 123
330+
fluid.default_main_program().random_seed = 123
331+
a = fluid.dygraph.to_variable(value)
332+
a.stop_gradient = False
333+
334+
out = model_f(a)
335+
out.backward(backward_strategy)
336+
337+
grad_2 = a.gradient()
338+
339+
self.assertTrue(np.array_equal(grad_1, grad_2))
340+
341+
298342
if __name__ == '__main__':
299343
unittest.main()

0 commit comments

Comments
 (0)