Skip to content

Commit 4b8d65a

Browse files
Author: chengduo
Add return_numpy for PE (#11792)
1 parent 59837ff commit 4b8d65a

File tree

6 files changed

+15
-11
lines changed

6 files changed

+15
-11
lines changed

python/paddle/fluid/executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def as_numpy(tensor):
7878
Returns:
7979
numpy.ndarray
8080
"""
81+
if isinstance(tensor, core.LoDTensorArray):
82+
return [as_numpy(t) for t in tensor]
8183
if isinstance(tensor, list):
8284
return [as_numpy(t) for t in tensor]
8385
assert isinstance(tensor, core.LoDTensor)

python/paddle/fluid/parallel_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(self,
160160
build_strategy, num_trainers, trainer_id)
161161
self.scope = scope
162162

163-
def run(self, fetch_list, feed=None, feed_dict=None):
163+
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
164164
"""
165165
Run a parallel executor with fetch_list.
166166
@@ -196,6 +196,8 @@ def run(self, fetch_list, feed=None, feed_dict=None):
196196
to each device. Default None.
197197
feed_dict: Alias for feed parameter, for backward compatibility.
198198
This parameter has been deprecated. Default None.
199+
return_numpy(bool): Whether converts the fetched tensor to numpy.
200+
Default: True.
199201
200202
Returns:
201203
List: The fetched result list.
@@ -270,6 +272,9 @@ def run(self, fetch_list, feed=None, feed_dict=None):
270272
if self.is_dist:
271273
self.bcast_params()
272274

275+
if return_numpy:
276+
return executor.as_numpy(arr)
277+
273278
return [arr[i] for i in range(len(arr))]
274279

275280
def bcast_params(self):

python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def run_executor(exe, feed, fetch_list, program=None):
8181
begin = time.time()
8282
first_loss, = run_executor(
8383
exe=exe, feed=feed_dict, fetch_list=[loss.name])
84-
first_loss = np.array(first_loss)
8584

8685
for i in xrange(iter):
8786
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
@@ -94,8 +93,6 @@ def run_executor(exe, feed, fetch_list, program=None):
9493
print "%.4f Instance per second" % (
9594
(batch_size * iter + 2) / (end - begin))
9695

97-
last_loss = np.array(last_loss)
98-
9996
print first_loss, last_loss
10097
# self.assertGreater(first_loss[0], last_loss[0])
10198
return first_loss, last_loss

python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ def check_network_convergence(self,
169169
data = train_data()
170170
for i in xrange(10):
171171
cur_batch = next(data)
172-
print map(np.array,
173-
pe.run(feed=feeder.feed(cur_batch),
174-
fetch_list=[avg_cost.name]))[0]
172+
print pe.run(feed=feeder.feed(cur_batch),
173+
fetch_list=[avg_cost.name])[0]
175174

176175
@unittest.skip(reason="CI hangs")
177176
def test_update_sparse_parameter_all_reduce(self):

python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def parallel_exe(self, train_inputs, seed, use_cuda):
7575
fetch_list.append(k)
7676

7777
for data in train_inputs:
78-
ret = pe.run(fetch_list, feed=feeder.feed(data))
78+
ret = pe.run(fetch_list,
79+
feed=feeder.feed(data),
80+
return_numpy=True)
7981
for i in range(len(fetch_list)):
8082
assert not math.isnan(np.sum(ret[i])) and \
8183
not math.isinf(np.sum(ret[i]))
@@ -128,7 +130,7 @@ def parallel_exe(self, use_cuda, seed):
128130
use_cuda=use_cuda, loss_name=loss.name, main_program=main)
129131

130132
for batch_id, data in enumerate(reader()):
131-
loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0])
133+
loss_np = pe.run(feed=data, fetch_list=[loss.name])[0]
132134
print batch_id, loss_np
133135
if batch_id == 2:
134136
break

python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ def check_network_convergence(self, use_cuda, build_strategy=None):
7070

7171
for i in xrange(5):
7272
test_loss, = test_exe.run([loss.name], feed=feed_dict)
73-
test_loss = np.array(test_loss)
7473

7574
train_loss, = train_exe.run([loss.name], feed=feed_dict)
76-
train_loss = np.array(train_loss)
75+
7776
self.assertTrue(
7877
np.allclose(
7978
train_loss, test_loss, atol=1e-8),

0 commit comments

Comments (0)