@@ -61,8 +61,8 @@ def __init__(self,
61
61
main_program=test_program,
62
62
share_vars_from=train_exe)
63
63
64
- train_loss, = train_exe.run([loss.name], feed_dict =feed_dict)
65
- test_loss, = test_exe.run([loss.name], feed_dict =feed_dict)
64
+ train_loss, = train_exe.run([loss.name], feed =feed_dict)
65
+ test_loss, = test_exe.run([loss.name], feed =feed_dict)
66
66
"""
67
67
68
68
self ._places = []
@@ -123,22 +123,23 @@ def __init__(self,
123
123
allow_op_delay )
124
124
self .scope = scope
125
125
126
- def run (self , fetch_list , feed_dict = {}):
126
+ def run (self , fetch_list , feed = {}, feed_dict = {}):
127
127
"""
128
128
:param fetch_list: A list of variable names that will be fetched.
129
- :param feed_dict : A dict mapping for feed variable name to LoDTensor
129
+ :param feed : A dict mapping for feed variable name to LoDTensor
130
130
or numpy array.
131
131
:return: fetched value list.
132
132
"""
133
- if not isinstance (feed_dict , dict ):
134
- raise TypeError ("feed_dict should be a dict" )
133
+ feed = feed_dict
134
+ if not isinstance (feed , dict ):
135
+ raise TypeError ("feed should be a dict" )
135
136
136
137
feed_tensor_dict = {}
137
- for i , feed_name in enumerate (feed_dict ):
138
- feed_tensor = feed_dict [feed_name ]
138
+ for i , feed_name in enumerate (feed ):
139
+ feed_tensor = feed [feed_name ]
139
140
if not isinstance (feed_tensor , core .LoDTensor ):
140
141
feed_tensor = core .LoDTensor ()
141
- feed_tensor .set (feed_dict [feed_name ], self ._act_places [0 ])
142
+ feed_tensor .set (feed [feed_name ], self ._act_places [0 ])
142
143
feed_tensor_dict [feed_name ] = feed_tensor
143
144
144
145
fetch_var_name = '@FETCHED_VAR_NAME@'
0 commit comments