@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
163
163
return tensor
164
164
165
165
166
+ def get_program_cache_key (feed , fetch_list ):
167
+ feed_var_names = feed .keys ()
168
+
169
+ def to_name_str (var ):
170
+ if isinstance (var , Variable ):
171
+ return var .desc .name ()
172
+ elif isinstance (var , str ):
173
+ return var
174
+ else :
175
+ raise TypeError (str (var ) + " should be Variable or str" )
176
+
177
+ fetch_var_names = map (to_name_str , fetch_list )
178
+
179
+ return str (feed_var_names + fetch_var_names )
180
+
181
+
166
182
class Executor (object ):
167
183
def __init__ (self , places ):
168
184
if not isinstance (places , list ) and not isinstance (places , tuple ):
@@ -232,12 +248,13 @@ def run(self,
232
248
233
249
Python executor takes a program, add feed operators and fetch operators to this program according
234
250
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
235
- the variables that user want to get after program run. Note: the executor will run all
251
+ the variables(or names) that user want to get after program run. Note: the executor will run all
236
252
operators in the program but not only the operators dependent by the fetch_list
237
253
238
254
:param program: the program that need to run, if not provied, then default_main_program will be used.
239
255
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
240
- :param fetch_list: a list of variable that user want to get, run will return them according to this list.
256
+ :param fetch_list: a list of variable or variable names that user want to get, run will return them according
257
+ to this list.
241
258
:param feed_var_name: the name for the input variable of feed Operator.
242
259
:param fetch_var_name: the name for the output variable of feed Operator.
243
260
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
@@ -247,6 +264,8 @@ def run(self,
247
264
"""
248
265
if feed is None :
249
266
feed = {}
267
+ if not isinstance (feed , dict ):
268
+ raise TypeError ("feed should be a map" )
250
269
if fetch_list is None :
251
270
fetch_list = []
252
271
@@ -260,10 +279,7 @@ def run(self,
260
279
scope = global_scope ()
261
280
262
281
program_cache = None
263
-
264
- feed_var_names = feed .keys ()
265
- fetch_var_names = [var .desc .name () for var in fetch_list ]
266
- program_cache_key = str (feed_var_names + fetch_var_names )
282
+ program_cache_key = get_program_cache_key (feed , fetch_list )
267
283
268
284
if use_program_cache :
269
285
# find program cache by cache_key
0 commit comments