20
20
import six
21
21
import logging
22
22
import pickle
23
+ import contextlib
23
24
from functools import reduce
24
25
25
26
import numpy as np
@@ -180,6 +181,17 @@ def _clone_var_in_block_(block, var):
180
181
persistable = True )
181
182
182
183
184
+ @contextlib .contextmanager
185
+ def _load_program_scope (main = None , startup = None , scope = None ):
186
+ prog = main if main else paddle .fluid .Program ()
187
+ startup_prog = startup if startup else paddle .fluid .Program ()
188
+ scope = scope if scope else paddle .fluid .core .Scope ()
189
+ with paddle .fluid .scope_guard (scope ):
190
+ with paddle .fluid .program_guard (prog , startup_prog ):
191
+ with paddle .fluid .unique_name .guard ():
192
+ yield
193
+
194
+
183
195
def _get_valid_program (main_program ):
184
196
if main_program is None :
185
197
main_program = default_main_program ()
@@ -1749,12 +1761,17 @@ def set_var(var, ndarray):
1749
1761
set_var (v , load_dict [v .name ])
1750
1762
1751
1763
1752
- def load_program_state (model_path ):
1764
+ def load_program_state (model_path , var_list = None ):
1753
1765
"""
1754
1766
Load program state from local file
1755
1767
1756
1768
Args:
1757
1769
model_path(str): The file prefix store the program
1770
+ var_list(list, optional): The variable list to load saved with
1771
+ [ save_params, save_persistables, save_vars ].
1772
+ Default: None.
1773
+ The var_list is only used to get name,
1774
+ will not be modified.
1758
1775
Returns:
1759
1776
state_dict(dict): the dict store Parameter and optimizer information
1760
1777
@@ -1775,14 +1792,94 @@ def load_program_state(model_path):
1775
1792
program_state = fluid.load_program_state( "./temp")
1776
1793
1777
1794
"""
1778
- parameter_file_name = model_path + ".pdparams"
1795
+ model_prefix = model_path
1796
+ if model_prefix .endswith (".pdparams" ):
1797
+ model_prefix = model_prefix [:- 9 ]
1798
+ elif model_prefix .endswith (".pdopt" ):
1799
+ model_prefix = model_prefix [:- 6 ]
1800
+ elif model_prefix .endswith (".pdmodel" ):
1801
+ model_prefix = model_prefix [:- 8 ]
1802
+
1803
+ parameter_file_name = model_prefix + ".pdparams"
1804
+ if not os .path .exists (parameter_file_name ):
1805
+ # model file saved with fluid.save is not found, try to load model file saved with
1806
+ # [save_vars, save_params, save_persistables]
1807
+ _logger .warning (
1808
+ "{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]" .
1809
+ format (parameter_file_name ))
1810
+
1811
+ var_name_list = []
1812
+ if var_list is None and os .path .isfile (model_path ):
1813
+ raise ValueError (
1814
+ "var_list can not be None when model_path is a file type" )
1815
+
1816
+ for root , dirs , files in os .walk (model_path , topdown = False ):
1817
+ for f in files :
1818
+ file_path = os .path .join (root , f )
1819
+ var_temp_name = os .path .relpath (file_path , model_path )
1820
+ var_temp_name = var_temp_name .replace ("\\ " , "/" )
1821
+ var_name_list .append (var_temp_name )
1822
+
1823
+ with _load_program_scope ():
1824
+ load_prog = Program ()
1825
+ load_block = load_prog .global_block ()
1826
+
1827
+ def clone_var_to_block (block , var ):
1828
+ if not isinstance (var , Variable ):
1829
+ raise TypeError ("value in var_list must be variable" )
1830
+ return block .create_var (
1831
+ name = var .name ,
1832
+ shape = var .shape ,
1833
+ dtype = var .dtype ,
1834
+ type = var .type ,
1835
+ lod_level = var .lod_level
1836
+ if var .desc .type () == core .VarDesc .VarType .LOD_TENSOR else
1837
+ None ,
1838
+ persistable = True )
1839
+
1840
+ loaded_var_list = []
1841
+
1842
+ if var_list is not None :
1843
+ for var in var_list :
1844
+ loaded_var_list .append (clone_var_to_block (load_block , var ))
1845
+ else :
1846
+ for var_name in var_name_list :
1847
+ loaded_var_list .append (
1848
+ load_block .create_var (
1849
+ name = var_name , persistable = True ))
1850
+
1851
+ place = paddle .fluid .CPUPlace ()
1852
+ exe = paddle .fluid .Executor (place )
1853
+
1854
+ try :
1855
+ if os .path .isfile (model_path ):
1856
+ dir_name , file_name = os .path .split (model_path )
1857
+ else :
1858
+ dir_name = model_path
1859
+ file_name = None
1860
+ load_vars (
1861
+ executor = exe ,
1862
+ dirname = dir_name ,
1863
+ vars = loaded_var_list ,
1864
+ filename = file_name )
1865
+ except :
1866
+ raise RuntimeError (
1867
+ "Failed to load model file , please make sure model file is saved with the "
1868
+ "following APIs: save_params, save_persistables, save_vars" )
1869
+ res_dict = {}
1870
+ for var in loaded_var_list :
1871
+ res_dict [var .name ] = np .asarray (paddle .fluid .global_scope (
1872
+ ).find_var (var .name ).get_tensor ())
1873
+
1874
+ return res_dict
1875
+
1779
1876
assert os .path .exists (parameter_file_name ), \
1780
1877
"Parameter file [{}] not exits" .format (parameter_file_name )
1781
1878
1782
1879
with open (parameter_file_name , 'rb' ) as f :
1783
1880
para_dict = pickle .load (f )
1784
1881
1785
- opt_file_name = model_path + ".pdopt"
1882
+ opt_file_name = model_prefix + ".pdopt"
1786
1883
if os .path .exists (opt_file_name ):
1787
1884
with open (opt_file_name , 'rb' ) as f :
1788
1885
opti_dict = pickle .load (f )
0 commit comments