Skip to content

Commit a64844a

Browse files
author
chengduo
authored
enable PE return numpy (#11704)
1 parent 991cedb commit a64844a

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

python/paddle/fluid/executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def as_numpy(tensor):
7878
Returns:
7979
numpy.ndarray
8080
"""
81+
if isinstance(tensor, core.LoDTensorArray):
82+
return [as_numpy(t) for t in tensor]
8183
if isinstance(tensor, list):
8284
return [as_numpy(t) for t in tensor]
8385
assert isinstance(tensor, core.LoDTensor)

python/paddle/fluid/parallel_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(self,
160160
build_strategy, num_trainers, trainer_id)
161161
self.scope = scope
162162

163-
def run(self, fetch_list, feed=None, feed_dict=None):
163+
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=False):
164164
"""
165165
Run a parallel executor with fetch_list.
166166
@@ -196,6 +196,8 @@ def run(self, fetch_list, feed=None, feed_dict=None):
196196
to each device. Default None.
197197
feed_dict: Alias for feed parameter, for backward compatibility.
198198
This parameter has been deprecated. Default None.
199+
return_numpy(bool): Whether converts the fetched tensor to numpy.
200+
Default: False.
199201
200202
Returns:
201203
List: The fetched result list.
@@ -270,6 +272,9 @@ def run(self, fetch_list, feed=None, feed_dict=None):
270272
if self.is_dist:
271273
self.bcast_params()
272274

275+
if return_numpy:
276+
return executor.as_numpy(arr)
277+
273278
return [arr[i] for i in range(len(arr))]
274279

275280
def bcast_params(self):

python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def parallel_exe(self, train_inputs, seed, use_cuda):
7575
fetch_list.append(k)
7676

7777
for data in train_inputs:
78-
ret = pe.run(fetch_list, feed=feeder.feed(data))
78+
ret = pe.run(fetch_list,
79+
feed=feeder.feed(data),
80+
return_numpy=True)
7981
for i in range(len(fetch_list)):
8082
assert not math.isnan(np.sum(ret[i])) and \
8183
not math.isinf(np.sum(ret[i]))

0 commit comments

Comments
 (0)