Commit d683d32

Author: Feiyu Chan
add coalesce_tensor into white list when checking re-creation of parameters (#31800) (#31916)
1 parent: 32df1fb

File tree

2 files changed (+58, -1 lines)


python/paddle/fluid/framework.py

Lines changed: 5 additions & 1 deletion
@@ -2965,7 +2965,11 @@ def _is_inited_by(block, var):
             # In startup_program, "c_broadcast" and "c_sync_comm_stream"
             # are treated as initialization ops that cause error.
             # Think of "c_broadcast" and "c_sync_comm_stream" as a special case here.
-            if op.type in ["c_broadcast", "c_sync_comm_stream"]:
+            # NOTE: "coalesce_tensor" is a special case for rnn with cudnn support
+            if op.type in [
+                    "c_broadcast", "c_sync_comm_stream",
+                    "coalesce_tensor"
+            ]:
                 continue
             init_ops.append(op)
     return init_ops
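
For context, here is a minimal sketch (not part of the commit, and assuming a local PaddlePaddle installation with static graph support) that builds the same kind of bidirectional LSTM the new test uses and lists the op types placed in its startup program. On builds where the RNN parameters are flattened for the cuDNN kernel, a coalesce_tensor op can appear among those initialization ops, which is the case the extended white list now tolerates.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Building the layer in static mode creates its parameters and adds
    # their initialization ops to startup_prog.
    paddle.nn.LSTM(32, 32, num_layers=1, direction='bidirectional')

# Op types that initialize the LSTM parameters; on GPU builds with cuDNN
# support this set can include "coalesce_tensor".
print(sorted({op.type for op in startup_prog.global_block().ops}))
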
New test file: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from unittest import TestCase
+
+
+def create_model():
+    hidden_size = 32
+    bilstm = paddle.nn.LSTM(
+        hidden_size, hidden_size, num_layers=1, direction='bidirectional')
+    return bilstm
+
+
+class TestRNNProgramClone(TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def test_rnn_with_cudnn_clone(self):
+        train_program = paddle.static.Program()
+        test_program = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+
+        # Test a typical case in static graph usage: create two nearly
+        # identical programs with a shared startup program so that they
+        # share their parameters.
+        #
+        # When creating a parameter, its name is checked. If there is already
+        # a parameter with the same name which is the output of an operator
+        # (i.e. its creator), its re-creation is skipped.
+        #
+        # But if that parameter has been the output of more than one operator,
+        # an exception is raised. For special cases, a white list is added;
+        # flattening an rnn's parameters in order to call the cudnn kernel is
+        # such a case.
+        with paddle.static.program_guard(train_program, startup_prog):
+            with paddle.fluid.unique_name.guard():
+                bilstm = create_model()
+
+        with paddle.fluid.program_guard(test_program, startup_prog):
+            with paddle.fluid.unique_name.guard():
+                bilstm = create_model()
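
To illustrate how this shared-startup pattern is completed at run time, here is a small standalone sketch (illustrative only, not part of the committed test; the executor and place choice are assumptions): the startup program is run exactly once, after which both nearly identical programs reuse the same parameters.

import paddle

paddle.enable_static()

train_program = paddle.static.Program()
test_program = paddle.static.Program()
startup_prog = paddle.static.Program()

# Build two nearly identical programs against one shared startup program.
with paddle.static.program_guard(train_program, startup_prog):
    with paddle.fluid.unique_name.guard():
        paddle.nn.LSTM(32, 32, num_layers=1, direction='bidirectional')

with paddle.static.program_guard(test_program, startup_prog):
    with paddle.fluid.unique_name.guard():
        paddle.nn.LSTM(32, 32, num_layers=1, direction='bidirectional')

# Running the startup program once initializes the shared parameters; the
# train and test programs can then both be executed with exe.run(...).
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)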
