Skip to content

Commit 09f1873

Browse files
authored
[ERNIE] Add parameter converter from static to dygraph. (#1478)
1 parent f631d3e commit 09f1873

File tree

3 files changed

+60
-3
lines changed

3 files changed

+60
-3
lines changed

examples/language_model/ernie-1.0/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ python -u -m paddle.distributed.launch \
8282
- 一般而言, `global_batch_size = micro_batch_size * sharding_degree * dp_degree`。可以使用梯度累积的方式增大`global_batch_size`。设置`global_batch_size`为理论值的整数倍时,默认启用梯度累积。
8383
- 训练断点重启,直接启动即可,程序会找到最新的checkpoint,开始重启训练。
8484

85+
### 其他
86+
#### 模型参数转换
87+
本示例提供了静态图训练脚本,但Paddle目前主要的使用方式是动态图。因此,本示例提供了静态图参数到动态图参数的转换脚本:
88+
89+
```python
90+
python converter/params_static_to_dygraph.py --model ernie-1.0 --path ./output/task_name/model_100000/static_vars
91+
# or
92+
python converter/params_static_to_dygraph.py --model ernie-1.0 --path ./output/task_name/model_last/static_vars.pdparams
93+
```
94+
在当前目录下,可以看到转换后的参数`ernie-1.0_converted.pdparams`, 也可以设置脚本中`--output_path`参数,指定输出路径。
95+
8596

8697
### 参考文献
8798
- [ERNIE: Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223.pdf)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import argparse
2+
import paddle
3+
from paddlenlp.transformers import AutoModel
4+
from paddlenlp.utils.log import logger
5+
6+
paddle.set_device("cpu")
7+
parser = argparse.ArgumentParser()
8+
parser.add_argument(
9+
"--model", type=str, help="The name of pretrained weights in PaddleNLP.")
10+
parser.add_argument(
11+
"--path", type=str, help="The path of checkpoint to be loaded.")
12+
parser.add_argument(
13+
"--output_path",
14+
type=str,
15+
default=None,
16+
help="The path of checkpoint to be loaded.")
17+
args = parser.parse_args()
18+
19+
20+
def init_dygraph_with_static(model, static_params_path):
21+
from paddlenlp.utils.tools import static_params_to_dygraph
22+
static_tensor_dict = paddle.static.load_program_state(static_params_path)
23+
return static_params_to_dygraph(model, static_tensor_dict)
24+
25+
26+
def main(args):
27+
logger.info("Loading model: %s" % args.model)
28+
model = AutoModel.from_pretrained(args.model)
29+
logger.info("Loading static params and trans paramters...")
30+
model_dict = init_dygraph_with_static(model, args.path)
31+
save_name = args.output_path
32+
if save_name is None:
33+
save_name = args.model + "_converted.pdparams"
34+
if not save_name.endswith(".pdparams"):
35+
save_name += ".pdparams"
36+
logger.info("Saving converted params to %s" % save_name)
37+
paddle.save(model_dict, save_name)
38+
logger.info("New pdparams saved!")
39+
40+
41+
if __name__ == "__main__":
42+
main(args)

paddlenlp/utils/tools.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import paddle
1615
import numpy as np
16+
import paddle
17+
from .log import logger
1718

1819

1920
def static_params_to_dygraph(model, static_tensor_dict):
@@ -34,6 +35,9 @@ def static_params_to_dygraph(model, static_tensor_dict):
3435

3536
ret_dict = dict()
3637
for n, p in state_dict.items():
38+
if p.name not in static_tensor_dict:
39+
logger.info("%s paramter is missing from you state dict." % n)
40+
continue
3741
ret_dict[n] = static_tensor_dict[p.name]
3842

3943
return ret_dict
@@ -56,7 +60,7 @@ def dygraph_params_to_static(model, dygraph_tensor_dict, topo=None):
5660
ret_dict = dict()
5761
for name, parm in state_dict.items():
5862
if name not in dygraph_tensor_dict:
59-
print("Miss \t\t", name)
63+
logger.info("%s paramter is missing from you state dict." % name)
6064
continue
6165

6266
tensor = dygraph_tensor_dict[name]
@@ -157,4 +161,4 @@ def compare_version(version, pair_version):
157161
return 1
158162
elif int(version_code) < int(pair_version_code):
159163
return -1
160-
return 0
164+
return 0

0 commit comments

Comments
 (0)