Skip to content

Commit 88600d7

Browse files
fix lod tensor return (#101)
Co-authored-by: tangwei <[email protected]>
1 parent 5a52180 commit 88600d7

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

core/trainers/framework/runner.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,52 @@
1616

1717
import os
1818
import time
19-
import warnings
20-
import datetime
21-
19+
import numpy as np
2220
import paddle.fluid as fluid
21+
2322
from paddlerec.core.utils import envs
2423

2524
__all__ = [
2625
"RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner"
2726
]
2827

2928

29+
def as_numpy(tensor):
30+
"""
31+
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
32+
For higher dimensional sequence data, please use LoDTensor directly.
33+
34+
Examples:
35+
.. code-block:: python
36+
37+
import paddle.fluid as fluid
38+
import numpy
39+
40+
new_scope = fluid.Scope()
41+
with fluid.scope_guard(new_scope):
42+
fluid.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), fluid.CPUPlace())
43+
tensor = new_scope.find_var("data").get_tensor()
44+
fluid.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
45+
46+
Args:
47+
tensor(Variable): a instance of Tensor
48+
49+
Returns:
50+
numpy.ndarray
51+
"""
52+
if isinstance(tensor, fluid.core.LoDTensorArray):
53+
return [as_numpy(t) for t in tensor]
54+
if isinstance(tensor, list):
55+
return [as_numpy(t) for t in tensor]
56+
assert isinstance(tensor, fluid.core.LoDTensor)
57+
lod = tensor.lod()
58+
# (todo) need print lod or return it for user
59+
if tensor._is_initialized():
60+
return np.array(tensor)
61+
else:
62+
return None
63+
64+
3065
class RunnerBase(object):
3166
"""R
3267
"""
@@ -92,9 +127,6 @@ def _executor_dataloader_train(self, model_dict, context):
92127
model_class = context["model"][model_dict["name"]]["model"]
93128
program = self._get_dataloader_program(model_dict, context)
94129

95-
reader_name = model_dict["dataset_name"]
96-
fetch_vars = []
97-
fetch_alias = []
98130
fetch_period = int(
99131
envs.get_global_env("runner." + context["runner_name"] +
100132
".print_interval", 20))
@@ -103,9 +135,6 @@ def _executor_dataloader_train(self, model_dict, context):
103135
else:
104136
metrics = model_class.get_metrics()
105137

106-
if metrics:
107-
fetch_vars = metrics.values()
108-
fetch_alias = metrics.keys()
109138
metrics_varnames = []
110139
metrics_format = []
111140
metrics_format.append("{}: {{}}".format("batch"))
@@ -121,9 +150,16 @@ def _executor_dataloader_train(self, model_dict, context):
121150
with fluid.scope_guard(scope):
122151
try:
123152
while True:
124-
metrics_rets = context["exe"].run(
125-
program=program, fetch_list=metrics_varnames)
153+
metrics_tensors = context["exe"].run(
154+
program=program,
155+
fetch_list=metrics_varnames,
156+
return_numpy=False)
126157
metrics = [batch_id]
158+
159+
metrics_rets = [
160+
as_numpy(metrics_tensor)
161+
for metrics_tensor in metrics_tensors
162+
]
127163
metrics.extend(metrics_rets)
128164

129165
if batch_id % fetch_period == 0 and batch_id != 0:
@@ -248,7 +284,7 @@ def save_inference_model():
248284
fetch_varnames = envs.get_global_env(
249285
name + "save_inference_fetch_varnames", [])
250286
if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
251-
len(feed_varnames) == 0 or len(fetch_varnames) == 0:
287+
len(feed_varnames) == 0 or len(fetch_varnames) == 0:
252288
return
253289
fetch_vars = [
254290
fluid.default_main_program().global_block().vars[varname]

core/trainers/general_trainer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,7 @@
1919
import os
2020

2121
from paddlerec.core.utils import envs
22-
from paddlerec.core.trainer import Trainer, EngineMode, FleetMode, Device
23-
from paddlerec.core.trainers.framework.dataset import *
24-
from paddlerec.core.trainers.framework.runner import *
25-
from paddlerec.core.trainers.framework.instance import *
26-
from paddlerec.core.trainers.framework.network import *
27-
from paddlerec.core.trainers.framework.startup import *
22+
from paddlerec.core.trainer import Trainer, EngineMode, FleetMode
2823

2924

3025
class GeneralTrainer(Trainer):

0 commit comments

Comments
 (0)