|
| 1 | +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import paddle |
| 16 | + |
| 17 | + |
def convert_params_for_cell(np_cell, paddle_cell):
    """Copy parameters from a numpy reference cell into a paddle cell.

    Dygraph variant: every named parameter of ``paddle_cell`` is
    overwritten with the array stored under the same name in
    ``np_cell.parameters``.
    """
    source = np_cell.parameters
    for name, param in paddle_cell.named_parameters():
        param.set_value(source[name])
| 22 | + |
| 23 | + |
def convert_params_for_cell_static(np_cell, paddle_cell, place):
    """Copy parameters from a numpy reference cell into a paddle cell.

    Static-graph variant: parameter values are written directly into the
    tensors held by the global scope, on the given ``place`` (device).
    """
    state = np_cell.parameters
    # global_scope() is loop-invariant; fetch it once instead of per parameter.
    scope = paddle.static.global_scope()
    for k, v in paddle_cell.named_parameters():
        tensor = scope.find_var(v.name).get_tensor()
        tensor.set(state[k], place)
| 30 | + |
| 31 | + |
def convert_params_for_net(np_net, paddle_net):
    """Copy parameters for every layer of a numpy net into a paddle net.

    A layer exposing a ``cell`` attribute is treated as unidirectional;
    otherwise its ``cell_fw``/``cell_bw`` pair is converted.
    """
    for np_layer, paddle_layer in zip(np_net, paddle_net):
        if hasattr(np_layer, "cell"):
            cell_pairs = [(np_layer.cell, paddle_layer.cell)]
        else:
            cell_pairs = [
                (np_layer.cell_fw, paddle_layer.cell_fw),
                (np_layer.cell_bw, paddle_layer.cell_bw),
            ]
        for src_cell, dst_cell in cell_pairs:
            convert_params_for_cell(src_cell, dst_cell)
| 39 | + |
| 40 | + |
def convert_params_for_net_static(np_net, paddle_net, place):
    """Copy parameters for every layer of a numpy net into a paddle net.

    Static-graph counterpart of ``convert_params_for_net``; ``place``
    selects the device the tensors are written on.
    """
    for np_layer, paddle_layer in zip(np_net, paddle_net):
        if hasattr(np_layer, "cell"):
            cell_pairs = [(np_layer.cell, paddle_layer.cell)]
        else:
            cell_pairs = [
                (np_layer.cell_fw, paddle_layer.cell_fw),
                (np_layer.cell_bw, paddle_layer.cell_bw),
            ]
        for src_cell, dst_cell in cell_pairs:
            convert_params_for_cell_static(src_cell, dst_cell, place)
| 52 | + |
| 53 | + |
def get_params_for_cell(np_cell, num_layers, idx):
    """Build (name, array) pairs for one cell's weights and biases.

    NOTE(review): despite its name, ``num_layers`` is used as the layer
    index when forming parameter names (e.g. ``"0.weight_0"``) — kept
    unchanged so keyword callers still work.

    Returns a ``(weight_list, bias_list)`` tuple, each holding the
    ih-parameter at slot ``idx`` and the hh-parameter at ``idx + 1``.
    """
    state = np_cell.parameters
    prefix = f"{num_layers}."
    weights = [
        (prefix + f"weight_{idx}", state["weight_ih"]),
        (prefix + f"weight_{idx + 1}", state["weight_hh"]),
    ]
    biases = [
        (prefix + f"bias_{idx}", state["bias_ih"]),
        (prefix + f"bias_{idx + 1}", state["bias_hh"]),
    ]
    return weights, biases
| 65 | + |
| 66 | + |
def get_params_for_net(np_net):
    """Collect (name, array) pairs for every cell in the net.

    All weight entries come first, followed by all bias entries.
    Bidirectional layers contribute their forward cell at slot 0 and
    backward cell at slot 2 (two parameters per cell).
    """
    weights = []
    biases = []
    for layer_idx, np_layer in enumerate(np_net):
        if hasattr(np_layer, "cell"):
            cells = [np_layer.cell]
        else:
            cells = [np_layer.cell_fw, np_layer.cell_bw]
        for slot, cell in enumerate(cells):
            w, b = get_params_for_cell(cell, layer_idx, slot * 2)
            weights.extend(w)
            biases.extend(b)
    return weights + biases
0 commit comments