Skip to content

Commit 1ee7784

Browse files
committed
add get_program_cache_key function
1 parent b63901f commit 1ee7784

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

python/paddle/fluid/executor.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
163163
return tensor
164164

165165

166+
def get_program_cache_key(feed, fetch_list):
167+
feed_var_names = feed.keys()
168+
169+
def to_name_str(var):
170+
if isinstance(var, Variable):
171+
return var.desc.name()
172+
elif isinstance(var, str):
173+
return var
174+
else:
175+
raise TypeError(str(var) + " should be Variable or str")
176+
177+
fetch_var_names = map(to_name_str, fetch_list)
178+
179+
return str(feed_var_names + fetch_var_names)
180+
181+
166182
class Executor(object):
167183
def __init__(self, places):
168184
if not isinstance(places, list) and not isinstance(places, tuple):
@@ -232,12 +248,13 @@ def run(self,
232248
233249
Python executor takes a program, add feed operators and fetch operators to this program according
234250
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
235-
the variables that user want to get after program run. Note: the executor will run all
251+
the variables(or names) that user want to get after program run. Note: the executor will run all
236252
operators in the program but not only the operators dependent by the fetch_list
237253
238254
:param program: the program that need to run, if not provied, then default_main_program will be used.
239255
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
240-
:param fetch_list: a list of variable that user want to get, run will return them according to this list.
256+
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
257+
to this list.
241258
:param feed_var_name: the name for the input variable of feed Operator.
242259
:param fetch_var_name: the name for the output variable of feed Operator.
243260
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
@@ -247,6 +264,8 @@ def run(self,
247264
"""
248265
if feed is None:
249266
feed = {}
267+
if not isinstance(feed, dict):
268+
raise TypeError("feed should be a map")
250269
if fetch_list is None:
251270
fetch_list = []
252271

@@ -260,10 +279,7 @@ def run(self,
260279
scope = global_scope()
261280

262281
program_cache = None
263-
264-
feed_var_names = feed.keys()
265-
fetch_var_names = [var.desc.name() for var in fetch_list]
266-
program_cache_key = str(feed_var_names + fetch_var_names)
282+
program_cache_key = get_program_cache_key(feed, fetch_list)
267283

268284
if use_program_cache:
269285
# find program cache by cache_key

python/paddle/fluid/tests/unittests/test_learning_rate_decay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
8989
exe.run(fluid.default_startup_program())
9090
for step in range(10):
9191
lr_val, = exe.run(fluid.default_main_program(),
92-
feed=[],
92+
feed={},
9393
fetch_list=[decayed_lr])
9494
python_decayed_lr = python_decay_fn(
9595
global_step=float(step), **kwargs)

0 commit comments

Comments
 (0)