Skip to content

Commit 22df230

Browse files
committed
rename 'feed_dict' in ParallelExecutor.run() to 'feed'
1 parent a78b928 commit 22df230

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

python/paddle/fluid/parallel_executor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __init__(self,
6161
main_program=test_program,
6262
share_vars_from=train_exe)
6363
64-
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
65-
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
64+
train_loss, = train_exe.run([loss.name], feed=feed_dict)
65+
test_loss, = test_exe.run([loss.name], feed=feed_dict)
6666
"""
6767

6868
self._places = []
@@ -123,22 +123,23 @@ def __init__(self,
123123
allow_op_delay)
124124
self.scope = scope
125125

126-
def run(self, fetch_list, feed_dict={}):
126+
def run(self, fetch_list, feed={}, feed_dict={}):
127127
"""
128128
:param fetch_list: A list of variable names that will be fetched.
129-
:param feed_dict: A dict mapping for feed variable name to LoDTensor
129+
:param feed: A dict mapping for feed variable name to LoDTensor
130130
or numpy array.
131131
:return: fetched value list.
132132
"""
133-
if not isinstance(feed_dict, dict):
134-
raise TypeError("feed_dict should be a dict")
133+
feed = feed_dict
134+
if not isinstance(feed, dict):
135+
raise TypeError("feed should be a dict")
135136

136137
feed_tensor_dict = {}
137-
for i, feed_name in enumerate(feed_dict):
138-
feed_tensor = feed_dict[feed_name]
138+
for i, feed_name in enumerate(feed):
139+
feed_tensor = feed[feed_name]
139140
if not isinstance(feed_tensor, core.LoDTensor):
140141
feed_tensor = core.LoDTensor()
141-
feed_tensor.set(feed_dict[feed_name], self._act_places[0])
142+
feed_tensor.set(feed[feed_name], self._act_places[0])
142143
feed_tensor_dict[feed_name] = feed_tensor
143144

144145
fetch_var_name = '@FETCHED_VAR_NAME@'

0 commit comments

Comments
 (0)