19
19
import warnings
20
20
import six
21
21
import logging
22
+ import pickle
22
23
from functools import reduce
23
24
24
25
import numpy as np
40
41
batch = paddle .batch
41
42
42
43
__all__ = [
43
- 'save_vars' , 'save_params' , 'save_persistables' , 'load_vars' , 'load_params' ,
44
- 'load_persistables' , 'save_inference_model' , 'load_inference_model' ,
45
- 'batch' , 'save' , 'load'
44
+ 'save_vars' ,
45
+ 'save_params' ,
46
+ 'save_persistables' ,
47
+ 'load_vars' ,
48
+ 'load_params' ,
49
+ 'load_persistables' ,
50
+ 'save_inference_model' ,
51
+ 'load_inference_model' ,
52
+ 'batch' ,
53
+ 'save' ,
54
+ 'load' ,
55
+ 'load_program_state' ,
56
+ 'set_program_state' ,
46
57
] + reader .__all__ + paddle .reader .__all__
47
58
48
59
_logger = get_logger (
@@ -96,7 +107,10 @@ def is_persistable(var):
96
107
97
108
98
109
def is_belong_to_optimizer (var ):
99
- return var .belong_to_optimizer
110
+ if not isinstance (var , Parameter ):
111
+ return is_persistable (var )
112
+
113
+ return False
100
114
101
115
102
116
def _clone_var_in_block_ (block , var ):
@@ -1505,15 +1519,21 @@ def save(program, model_path):
1505
1519
assert base_name != "" , \
1506
1520
"model_path MUST be format of dirname/filename [dirname\\ filename in Window], Now filename is empty str"
1507
1521
1522
+ def get_tensor (var ):
1523
+ t = global_scope ().find_var (var .name ).get_tensor ()
1524
+ return np .array (t )
1525
+
1508
1526
parameter_list = list (filter (is_parameter , program .list_vars ()))
1509
- paddle .fluid .core ._save_static_dict (model_path + ".pdparams" ,
1510
- parameter_list , global_scope ())
1527
+ param_dict = {p .name : get_tensor (p ) for p in parameter_list }
1528
+ with open (model_path + ".pdparams" , 'wb' ) as f :
1529
+ pickle .dump (param_dict , f )
1511
1530
1512
1531
optimizer_var_list = list (
1513
1532
filter (is_belong_to_optimizer , program .list_vars ()))
1514
1533
1515
- paddle .fluid .core ._save_static_dict (model_path + ".pdopt" ,
1516
- optimizer_var_list , global_scope ())
1534
+ opt_dict = {p .name : get_tensor (p ) for p in optimizer_var_list }
1535
+ with open (model_path + ".pdopt" , 'wb' ) as f :
1536
+ pickle .dump (opt_dict , f )
1517
1537
1518
1538
main_program = program .clone ()
1519
1539
program .desc .flush ()
@@ -1524,16 +1544,16 @@ def save(program, model_path):
1524
1544
f .write (program .desc .serialize_to_string ())
1525
1545
1526
1546
1527
- def load (program , model_path ):
1547
+ def load (program , model_path , executor = None ):
1528
1548
"""
1529
1549
This function filter out parameters and optimizer information from program, and then get corresponding value from file.
1530
- An exception will throw if shape or dtype of the parameters is not match between program and loaded file.
1531
-
1532
- NOTICE: This function MUST called after run start_up_program
1550
+ An exception will throw if shape or dtype of the parameters is not match.
1533
1551
1534
1552
Args:
1535
- program: The program to be load
1536
- model_path: The file prefix store the program
1553
+ program(Program): The program will be loaded
1554
+ model_path(str): The file prefix store the program
1555
+ executor(Executor, optional): The executor used for initialize the parameter
1556
+ When startup program is not run.
1537
1557
1538
1558
Returns:
1539
1559
None
@@ -1550,13 +1570,39 @@ def load(program, model_path):
1550
1570
1551
1571
"""
1552
1572
1573
+ assert executor is None or isinstance (executor , Executor )
1574
+
1553
1575
parameter_file_name = model_path + ".pdparams"
1554
1576
assert os .path .exists (parameter_file_name ), \
1555
- "Parameter file [{}] not exits" .format ( parameter_file_name )
1577
+ "Parameter file [{}] not exits" .format (parameter_file_name )
1578
+
1579
+ def set_var (var , ndarray ):
1580
+ t = global_scope ().find_var (var .name ).get_tensor ()
1581
+ p = t ._place ()
1582
+ if p .is_cpu_place ():
1583
+ place = paddle .fluid .CPUPlace ()
1584
+ elif p .is_cuda_pinned_place ():
1585
+ place = paddle .fluid .CUDAPinnedPlace ()
1586
+ else :
1587
+ p = paddle .fluid .core .Place ()
1588
+ p .set_place (t ._place ())
1589
+ place = paddle .fluid .CUDAPlace (p .gpu_device_id ())
1590
+
1591
+ t .set (ndarray , place )
1556
1592
1557
1593
parameter_list = list (filter (is_parameter , program .list_vars ()))
1558
- paddle .fluid .core ._load_static_dict (parameter_file_name , parameter_list ,
1559
- global_scope ())
1594
+
1595
+ if executor :
1596
+ paddle .fluid .core ._create_loaded_parameter (parameter_list ,
1597
+ global_scope (),
1598
+ executor ._default_executor )
1599
+ with open (parameter_file_name , 'rb' ) as f :
1600
+ load_dict = pickle .load (f )
1601
+ for v in parameter_list :
1602
+ assert v .name in load_dict , \
1603
+ "Can not find [{}] in model file [{}]" .format (
1604
+ v .name , parameter_file_name )
1605
+ set_var (v , load_dict [v .name ])
1560
1606
1561
1607
optimizer_var_list = list (
1562
1608
filter (is_belong_to_optimizer , program .list_vars ()))
@@ -1565,5 +1611,138 @@ def load(program, model_path):
1565
1611
opt_file_name = model_path + ".pdopt"
1566
1612
assert os .path .exists (opt_file_name ), \
1567
1613
"Optimizer file [{}] not exits" .format ( opt_file_name )
1568
- paddle .fluid .core ._load_static_dict (opt_file_name , optimizer_var_list ,
1569
- global_scope ())
1614
+
1615
+ if executor :
1616
+ paddle .fluid .core ._create_loaded_parameter (
1617
+ optimizer_var_list , global_scope (), executor ._default_executor )
1618
+
1619
+ with open (opt_file_name , 'rb' ) as f :
1620
+ load_dict = pickle .load (f )
1621
+ for v in optimizer_var_list :
1622
+ assert v .name in load_dict , \
1623
+ "Can not find [{}] in model file [{}]" .format (
1624
+ v .name , opt_file_name )
1625
+ set_var (v , load_dict [v .name ])
1626
+
1627
+
1628
+ def load_program_state (model_path ):
1629
+ """
1630
+ Load program state from local file
1631
+
1632
+ Args:
1633
+ model_path(str): The file prefix store the program
1634
+ Returns:
1635
+ state_dict(dict): the dict store Parameter and optimizer information
1636
+
1637
+ Examples:
1638
+ .. code-block:: python
1639
+
1640
+ import paddle.fluid as fluid
1641
+ x = fluid.data( name="x", shape=[10, 10], dtype='float32')
1642
+ y = fluid.layers.fc( x, 10)
1643
+ z = fluid.layers.fc( y, 10)
1644
+
1645
+ place = fluid.CPUPlace()
1646
+ exe = fluid.Executor(place)
1647
+ exe.run( fluid.default_startup_program() )
1648
+ prog = fluid.default_main_program()
1649
+
1650
+ fluid.save( prog, "./temp")
1651
+ program_state = fluid.load_program_state( "./temp")
1652
+
1653
+ fluid.set_program_state( prog, program_state)
1654
+
1655
+ """
1656
+ parameter_file_name = model_path + ".pdparams"
1657
+ assert os .path .exists (parameter_file_name ), \
1658
+ "Parameter file [{}] not exits" .format ( parameter_file_name )
1659
+
1660
+ with open (parameter_file_name , 'rb' ) as f :
1661
+ para_dict = pickle .load (f )
1662
+
1663
+ opt_file_name = model_path + ".pdopt"
1664
+ if os .path .exists (opt_file_name ):
1665
+ with open (opt_file_name , 'rb' ) as f :
1666
+ opti_dict = pickle .load (f )
1667
+
1668
+ para_dict .update (opti_dict )
1669
+
1670
+ return para_dict
1671
+
1672
+
1673
+ def set_program_state (program , state_dict ):
1674
+ """
1675
+ Set program parameter from state_dict
1676
+
1677
+ An exception will throw if shape or dtype of the parameters is not match.
1678
+
1679
+ NOTICE: This function MUST called after run start_up_program
1680
+
1681
+ Args:
1682
+ program(Program): The program to be set
1683
+ state_dict(dict): the dict store Parameter and optimizer information
1684
+ Returns:
1685
+ None
1686
+
1687
+ Examples:
1688
+ .. code-block:: python
1689
+
1690
+ import paddle.fluid as fluid
1691
+ x = fluid.data( name="x", shape=[10, 10], dtype='float32')
1692
+ y = fluid.layers.fc( x, 10)
1693
+ z = fluid.layers.fc( y, 10)
1694
+
1695
+ place = fluid.CPUPlace()
1696
+ exe = fluid.Executor(place)
1697
+ exe.run( fluid.default_startup_program() )
1698
+ prog = fluid.default_main_program()
1699
+
1700
+ fluid.save( prog, "./temp")
1701
+ program_state = fluid.load_program_state( "./temp")
1702
+
1703
+ """
1704
+ parameter_list = list (filter (is_persistable , program .list_vars ()))
1705
+
1706
+ used_para_list = {}
1707
+ for para in parameter_list :
1708
+ var_temp = paddle .fluid .global_scope ().find_var (para .name )
1709
+ assert var_temp != None , \
1710
+ "Variable [ {} ] Not found, Please make sure run startup program" .format ( para .name )
1711
+ if para .name in state_dict :
1712
+ # set value from state dict
1713
+ orig_para_np = np .array (var_temp .get_tensor ())
1714
+ new_para_np = state_dict [para .name ]
1715
+ assert orig_para_np .shape == new_para_np .shape , \
1716
+ "Shape not matching: the Program requires a parameter with a shape of ({}), " \
1717
+ "while the loaded parameter (namely [ {} ]) has a shape of ({})." \
1718
+ .format (orig_para_np .shape , para .name , new_para_np .shape )
1719
+ assert orig_para_np .dtype == new_para_np .dtype , \
1720
+ "Dtype not matching: the Program requires a parameter with a dtype of ({}), " \
1721
+ "while the loaded parameter (namely [ {} ]) has a dtype of ({})." \
1722
+ .format (orig_para_np .dtype , para .name , new_para_np .dtype )
1723
+
1724
+ ten = var_temp .get_tensor ()
1725
+ ten_place = ten ._place ()
1726
+
1727
+ assert ten_place .is_gpu_place () or ten_place .is_cpu_place (), \
1728
+ "Place not support, only support CPUPlace and GPUPlace, now is {}" .format ( str (ten_place ))
1729
+ py_place = paddle .fluid .CPUPlace ()
1730
+ if ten_place .is_cuda_pinned_place ():
1731
+ place = paddle .fluid .CUDAPinnedPlace ()
1732
+ elif ten_place .is_gpu_place ():
1733
+ p = paddle .fluid .core .Place ()
1734
+ p .set_place (ten_place )
1735
+ py_place = paddle .fluid .CUDAPlace (p .gpu_device_id ())
1736
+
1737
+ ten .set (new_para_np , py_place )
1738
+
1739
+ used_para_list [para .name ] = 1
1740
+
1741
+ unused_para_list = []
1742
+ for k , v in state_dict .items ():
1743
+ if k not in used_para_list :
1744
+ unused_para_list .append (k )
1745
+ if len (unused_para_list ) > 0 :
1746
+ warnings .warn (
1747
+ "This list is not set, Because of Paramerter not found in program. There are: {}" .
1748
+ format (" " .join (unused_para_list )))
0 commit comments