Skip to content

Commit b69382b

Browse files
[ILUVATAR_GPU] Support rnn op (#1924)
1 parent d7cf477 commit b69382b

File tree

8 files changed

+1976
-10
lines changed

8 files changed

+1976
-10
lines changed

backends/iluvatar_gpu/kernels/cuda_kernels/rnn_grad_kernel.cu

Lines changed: 486 additions & 0 deletions
Large diffs are not rendered by default.

backends/iluvatar_gpu/kernels/cuda_kernels/rnn_kernel.cu

Lines changed: 465 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
file(
2+
GLOB TEST_OPS
3+
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
4+
"test_*.py")
5+
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
6+
7+
foreach(TEST_OP ${TEST_OPS})
8+
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
9+
endforeach()
10+
11+
if(NOT WIN32)
12+
set_tests_properties(test_rnn_nets_static PROPERTIES TIMEOUT 120)
13+
set_tests_properties(test_rnn_nets PROPERTIES TIMEOUT 120)
14+
endif()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
17+
18+
def convert_params_for_cell(np_cell, paddle_cell):
19+
state = np_cell.parameters
20+
for k, v in paddle_cell.named_parameters():
21+
v.set_value(state[k])
22+
23+
24+
def convert_params_for_cell_static(np_cell, paddle_cell, place):
25+
state = np_cell.parameters
26+
for k, v in paddle_cell.named_parameters():
27+
scope = paddle.static.global_scope()
28+
tensor = scope.find_var(v.name).get_tensor()
29+
tensor.set(state[k], place)
30+
31+
32+
def convert_params_for_net(np_net, paddle_net):
33+
for np_layer, paddle_layer in zip(np_net, paddle_net):
34+
if hasattr(np_layer, "cell"):
35+
convert_params_for_cell(np_layer.cell, paddle_layer.cell)
36+
else:
37+
convert_params_for_cell(np_layer.cell_fw, paddle_layer.cell_fw)
38+
convert_params_for_cell(np_layer.cell_bw, paddle_layer.cell_bw)
39+
40+
41+
def convert_params_for_net_static(np_net, paddle_net, place):
42+
for np_layer, paddle_layer in zip(np_net, paddle_net):
43+
if hasattr(np_layer, "cell"):
44+
convert_params_for_cell_static(np_layer.cell, paddle_layer.cell, place)
45+
else:
46+
convert_params_for_cell_static(
47+
np_layer.cell_fw, paddle_layer.cell_fw, place
48+
)
49+
convert_params_for_cell_static(
50+
np_layer.cell_bw, paddle_layer.cell_bw, place
51+
)
52+
53+
54+
def get_params_for_cell(np_cell, num_layers, idx):
55+
state = np_cell.parameters
56+
weight_list = [
57+
(f"{num_layers}.weight_{idx}", state["weight_ih"]),
58+
(f"{num_layers}.weight_{idx + 1}", state["weight_hh"]),
59+
]
60+
bias_list = [
61+
(f"{num_layers}.bias_{idx}", state["bias_ih"]),
62+
(f"{num_layers}.bias_{idx + 1}", state["bias_hh"]),
63+
]
64+
return weight_list, bias_list
65+
66+
67+
def get_params_for_net(np_net):
68+
weight_list = []
69+
bias_list = []
70+
for layer_idx, np_layer in enumerate(np_net):
71+
if hasattr(np_layer, "cell"):
72+
weight, bias = get_params_for_cell(np_layer.cell, layer_idx, 0)
73+
for w, b in zip(weight, bias):
74+
weight_list.append(w)
75+
bias_list.append(b)
76+
else:
77+
for count, cell in enumerate([np_layer.cell_fw, np_layer.cell_bw]):
78+
weight, bias = get_params_for_cell(cell, layer_idx, count * 2)
79+
for w, b in zip(weight, bias):
80+
weight_list.append(w)
81+
bias_list.append(b)
82+
83+
weight_list.extend(bias_list)
84+
return weight_list

0 commit comments

Comments
 (0)