Skip to content

Commit 6cfc0c1

Browse files
authored
"polish code" (#9318)
* "polish code" * "fix ci" * "fix ci" * "done"
1 parent b55dc9a commit 6cfc0c1

File tree

1 file changed

+18
-55
lines changed

1 file changed

+18
-55
lines changed

python/paddle/fluid/executor.py

Lines changed: 18 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def as_numpy(tensor):
4848
assert isinstance(tensor, core.LoDTensor)
4949
lod = tensor.lod()
5050
if len(lod) > 0:
51-
raise RuntimeError(
52-
"Some of your featched tensors hold LoD information. \
51+
raise RuntimeError("Some of your fetched tensors hold LoD information. \
5352
They can not be completely cast to Python ndarray. \
5453
Please set the parameter 'return_numpy' as 'False' to \
5554
return LoDTensor itself directly.")
@@ -180,60 +179,24 @@ def to_name_str(var):
180179

181180

182181
class Executor(object):
183-
def __init__(self, places):
184-
if not isinstance(places, list) and not isinstance(places, tuple):
185-
places = [places]
186-
187-
act_places = []
188-
for each in places:
189-
p = core.Place()
190-
p.set_place(each)
191-
act_places.append(p)
192-
193-
# TODO(dzhwinter) : only use the first place
194-
self.executor = core.Executor(act_places[0])
195-
self.places = places
182+
def __init__(self, place):
183+
self.place = place
184+
p = core.Place()
185+
p.set_place(place)
186+
self.executor = core.Executor(p)
196187
self.program_caches = dict()
197188

198-
def aslodtensor(self, data):
199-
def accumulate(data):
200-
if not isinstance(data, list):
201-
return 1
202-
return sum([accumulate(sub) for sub in data])
203-
204-
def parselod(data):
205-
seq_lens = [accumulate(seq) for seq in data]
206-
cur_len = 0
207-
lod = [cur_len]
208-
for l in seq_lens:
209-
cur_len += l
210-
lod.append(cur_len)
211-
return lod
212-
213-
assert len(self.places) != 0
214-
if not isinstance(data, list):
215-
# pure tensor case
216-
tensor = core.LoDTensor()
217-
tensor.set(data, self.places[0])
218-
return tensor
219-
else:
220-
raise RuntimeError("Current implementation lacks unittests")
221-
# lodtensor case
222-
lod = []
223-
if not isinstance(data[0], list):
224-
lod.append(parselod(data))
225-
flattened_data = np.concatenate(data, axis=0).astype("int64")
226-
else:
227-
while isinstance(data[0], list):
228-
lod.append(parselod(seq))
229-
flattened_data = [item for seq in data for item in seq]
230-
data = flattened_data
231-
flattened_data = np.concatenate(data, axis=0).astype("int64")
232-
flattened_data = flattened_data.reshape([len(flattened_data), 1])
233-
tensor = core.LoDTensor()
234-
tensor.set(flattened_data, self.places[0])
235-
tensor.set_lod(lod)
236-
return tensor
189+
def as_lodtensor(self, data):
190+
if isinstance(data, list):
191+
raise RuntimeError("Some of your feed data hold LoD information. \
192+
They can not be completely cast from a list of Python \
193+
ndarray to LoDTensor. Please convert data to LoDTensor \
194+
directly before feeding the data.\
195+
")
196+
# single tensor case
197+
tensor = core.LoDTensor()
198+
tensor.set(data, self.place)
199+
return tensor
237200

238201
def _get_program_cache(self, program_cache_key):
239202
return self.program_caches.get(program_cache_key, None)
@@ -293,7 +256,7 @@ def _feed_data(self, program, feed, feed_var_name, scope):
293256
feed_target_name = op.desc.output('Out')[0]
294257
cur_feed = feed[feed_target_name]
295258
if not isinstance(cur_feed, core.LoDTensor):
296-
cur_feed = self.aslodtensor(cur_feed)
259+
cur_feed = self.as_lodtensor(cur_feed)
297260
idx = op.desc.attr('col')
298261
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
299262
else:

0 commit comments

Comments
 (0)