Skip to content

Commit 2057f76

Browse files
authored
Enhance load program state (#22546) (#22589)
* enhance load program state; test=develop * optimize comment; test=develop
1 parent 3f4687b commit 2057f76

File tree

2 files changed

+412
-3
lines changed

2 files changed

+412
-3
lines changed

python/paddle/fluid/io.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import six
2121
import logging
2222
import pickle
23+
import contextlib
2324
from functools import reduce
2425

2526
import numpy as np
@@ -180,6 +181,17 @@ def _clone_var_in_block_(block, var):
180181
persistable=True)
181182

182183

184+
@contextlib.contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
    """
    Run the enclosed block under a scope guard, a program guard and a
    unique-name guard.

    Args:
        main: Object passed to paddle.fluid.program_guard as the first
            argument. A new paddle.fluid.Program() is created when the
            value is None (or otherwise falsy).
        startup: Object passed to paddle.fluid.program_guard as the second
            argument. A new paddle.fluid.Program() is created when the
            value is None (or otherwise falsy).
        scope: Object passed to paddle.fluid.scope_guard. A new
            paddle.fluid.core.Scope() is created when the value is None
            (or otherwise falsy).

    Yields:
        None. Control is handed to the body of the ``with`` statement while
        all three guards are active.
    """
    active_main = main or paddle.fluid.Program()
    active_startup = startup or paddle.fluid.Program()
    active_scope = scope or paddle.fluid.core.Scope()

    # The guards are entered left to right (scope, then programs, then
    # unique names) and released in the reverse order on exit.
    with paddle.fluid.scope_guard(active_scope), \
            paddle.fluid.program_guard(active_main, active_startup), \
            paddle.fluid.unique_name.guard():
        yield
193+
194+
183195
def _get_valid_program(main_program):
184196
if main_program is None:
185197
main_program = default_main_program()
@@ -1749,12 +1761,17 @@ def set_var(var, ndarray):
17491761
set_var(v, load_dict[v.name])
17501762

17511763

1752-
def load_program_state(model_path):
1764+
def load_program_state(model_path, var_list=None):
17531765
"""
17541766
Load program state from local file
17551767
17561768
Args:
17571769
model_path(str): The file prefix store the program
1770+
var_list(list, optional): The variable list to load saved with
1771+
[ save_params, save_persistables, save_vars ].
1772+
Default: None.
1773+
The var_list is only used to get name,
1774+
will not be modified.
17581775
Returns:
17591776
state_dict(dict): the dict store Parameter and optimizer information
17601777
@@ -1775,14 +1792,94 @@ def load_program_state(model_path):
17751792
program_state = fluid.load_program_state( "./temp")
17761793
17771794
"""
1778-
parameter_file_name = model_path + ".pdparams"
1795+
model_prefix = model_path
1796+
if model_prefix.endswith(".pdparams"):
1797+
model_prefix = model_prefix[:-9]
1798+
elif model_prefix.endswith(".pdopt"):
1799+
model_prefix = model_prefix[:-6]
1800+
elif model_prefix.endswith(".pdmodel"):
1801+
model_prefix = model_prefix[:-8]
1802+
1803+
parameter_file_name = model_prefix + ".pdparams"
1804+
if not os.path.exists(parameter_file_name):
1805+
# model file saved with fluid.save is not found, try to load model file saved with
1806+
# [save_vars, save_params, save_persistables]
1807+
_logger.warning(
1808+
"{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
1809+
format(parameter_file_name))
1810+
1811+
var_name_list = []
1812+
if var_list is None and os.path.isfile(model_path):
1813+
raise ValueError(
1814+
"var_list can not be None when model_path is a file type")
1815+
1816+
for root, dirs, files in os.walk(model_path, topdown=False):
1817+
for f in files:
1818+
file_path = os.path.join(root, f)
1819+
var_temp_name = os.path.relpath(file_path, model_path)
1820+
var_temp_name = var_temp_name.replace("\\", "/")
1821+
var_name_list.append(var_temp_name)
1822+
1823+
with _load_program_scope():
1824+
load_prog = Program()
1825+
load_block = load_prog.global_block()
1826+
1827+
def clone_var_to_block(block, var):
1828+
if not isinstance(var, Variable):
1829+
raise TypeError("value in var_list must be variable")
1830+
return block.create_var(
1831+
name=var.name,
1832+
shape=var.shape,
1833+
dtype=var.dtype,
1834+
type=var.type,
1835+
lod_level=var.lod_level
1836+
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR else
1837+
None,
1838+
persistable=True)
1839+
1840+
loaded_var_list = []
1841+
1842+
if var_list is not None:
1843+
for var in var_list:
1844+
loaded_var_list.append(clone_var_to_block(load_block, var))
1845+
else:
1846+
for var_name in var_name_list:
1847+
loaded_var_list.append(
1848+
load_block.create_var(
1849+
name=var_name, persistable=True))
1850+
1851+
place = paddle.fluid.CPUPlace()
1852+
exe = paddle.fluid.Executor(place)
1853+
1854+
try:
1855+
if os.path.isfile(model_path):
1856+
dir_name, file_name = os.path.split(model_path)
1857+
else:
1858+
dir_name = model_path
1859+
file_name = None
1860+
load_vars(
1861+
executor=exe,
1862+
dirname=dir_name,
1863+
vars=loaded_var_list,
1864+
filename=file_name)
1865+
except:
1866+
raise RuntimeError(
1867+
"Failed to load model file , please make sure model file is saved with the "
1868+
"following APIs: save_params, save_persistables, save_vars")
1869+
res_dict = {}
1870+
for var in loaded_var_list:
1871+
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
1872+
).find_var(var.name).get_tensor())
1873+
1874+
return res_dict
1875+
17791876
assert os.path.exists(parameter_file_name), \
17801877
"Parameter file [{}] not exits".format(parameter_file_name)
17811878

17821879
with open(parameter_file_name, 'rb') as f:
17831880
para_dict = pickle.load(f)
17841881

1785-
opt_file_name = model_path + ".pdopt"
1882+
opt_file_name = model_prefix + ".pdopt"
17861883
if os.path.exists(opt_file_name):
17871884
with open(opt_file_name, 'rb') as f:
17881885
opti_dict = pickle.load(f)

0 commit comments

Comments
 (0)