
Commit ccd53c1

Flatten all the output tensors into a flat list when the returned outputs contain a list of lists.

1 parent 5249675 commit ccd53c1
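
For context, a minimal sketch (not part of the commit; the sample output is hypothetical) of the nested-output case this change targets. Before this commit, check_outputs probed the top-level outputs element by element, so an inner list would be handed to tensor.dtype and would fail:

import paddle

# Hypothetical model output: a plain tensor plus an inner list of results.
outputs = [paddle.ones([2, 3]), [paddle.zeros([4], dtype="int64"), None]]

# The old per-element probing assumed every element is a tensor or None:
# str(outputs[1].dtype)  # would raise AttributeError: 'list' object has no attribute 'dtype'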

File tree: 1 file changed (+41 −32 lines)

graph_net/paddle/test_compiler.py

Lines changed: 41 additions & 32 deletions
@@ -221,47 +221,56 @@ def measure_performance(model_call, args, compiler, profile=False):
 
 
 def check_outputs(args, expected_out, compiled_out):
-    if isinstance(expected_out, paddle.Tensor):
-        expected_out = [expected_out]
-    if isinstance(compiled_out, paddle.Tensor):
-        compiled_out = [compiled_out]
-
-    eager_dtypes = [None] * len(expected_out)
-    for i, tensor in enumerate(expected_out):
-        eager_dtypes[i] = (
-            str(tensor.dtype).replace("paddle.", "") if tensor is not None else "None"
-        )
-
-    compiled_dtypes = [None] * len(compiled_out)
-    for i, tensor in enumerate(compiled_out):
-        compiled_dtypes[i] = (
-            str(tensor.dtype).replace("paddle.", "") if tensor is not None else "None"
-        )
-
+    def _flatten_outputs_to_list(outs):
+        flattened_outs = outs
+        if isinstance(outs, paddle.Tensor):
+            flattened_outs = [outs]
+        else:
+            flattened_outs = [
+                x
+                for out in outs
+                for x in (out if isinstance(out, (tuple, list)) else (out,))
+            ]
+        return flattened_outs
+
+    expected_out = _flatten_outputs_to_list(expected_out)
+    compiled_out = _flatten_outputs_to_list(compiled_out)
+
+    def _get_output_dtypes(outs):
+        dtypes = [
+            str(tensor.dtype).replace("paddle.", "")
+            if isinstance(tensor, paddle.Tensor)
+            else None
+            for i, tensor in enumerate(outs)
+        ]
+        return dtypes
+
+    eager_dtypes = _get_output_dtypes(expected_out)
+    compiled_dtypes = _get_output_dtypes(compiled_out)
     type_match = test_compiler_util.check_output_datatype(
         args, eager_dtypes, compiled_dtypes
     )
 
-    eager_shapes = [None] * len(expected_out)
-    for i, tensor in enumerate(expected_out):
-        eager_shapes[i] = tensor.shape if tensor is not None else None
-
-    compiled_shapes = [None] * len(compiled_out)
-    for i, tensor in enumerate(compiled_out):
-        compiled_shapes[i] = tensor.shape if tensor is not None else None
+    def _get_output_shapes(outs):
+        shapes = [
+            tensor.shape if isinstance(tensor, paddle.Tensor) else None
+            for i, tensor in enumerate(outs)
+        ]
+        return shapes
 
+    eager_shapes = _get_output_shapes(expected_out)
+    compiled_shapes = _get_output_shapes(compiled_out)
     shape_match = test_compiler_util.check_output_shape(
         args, eager_shapes, compiled_shapes
    )
 
-    def transfer_to_float(origin_outputs):
+    def _transfer_to_float(origin_outputs):
         outputs = []
         for item in origin_outputs:
-            if (
-                item is not None
-                and isinstance(item, paddle.Tensor)
-                and item.dtype not in [paddle.float32, paddle.float64]
-            ):
+            if isinstance(item, paddle.Tensor) and item.dtype not in [
+                paddle.float32,
+                paddle.float64,
+            ]:
                 item = item.astype("float32")
             outputs.append(item)
         return outputs
@@ -274,8 +283,8 @@ def transfer_to_float(origin_outputs):
         cmp_equal_func=get_cmp_equal,
     )
 
-    expected_out_fp32 = transfer_to_float(expected_out)
-    compiled_out_fp32 = transfer_to_float(compiled_out)
+    expected_out_fp32 = _transfer_to_float(expected_out)
+    compiled_out_fp32 = _transfer_to_float(compiled_out)
     test_compiler_util.check_allclose(
         args,
         expected_out_fp32,
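
A rough usage sketch of the new behavior (standalone, mirroring the nested helpers added above; the sample outputs are hypothetical): flattening first means the dtype and shape probes only ever see tensors or None.

import paddle

def _flatten_outputs_to_list(outs):
    # Same logic as the helper added in this commit: wrap a bare tensor,
    # otherwise splice inner tuples/lists into one flat list.
    if isinstance(outs, paddle.Tensor):
        return [outs]
    return [
        x
        for out in outs
        for x in (out if isinstance(out, (tuple, list)) else (out,))
    ]

# Hypothetical nested output: a tensor plus an inner list of results.
nested = [paddle.ones([2, 3]), [paddle.zeros([4], dtype="int64"), None]]
flat = _flatten_outputs_to_list(nested)

# Dtype/shape extraction as in _get_output_dtypes / _get_output_shapes.
dtypes = [
    str(t.dtype).replace("paddle.", "") if isinstance(t, paddle.Tensor) else None
    for t in flat
]
shapes = [t.shape if isinstance(t, paddle.Tensor) else None for t in flat]

print(dtypes)  # ['float32', 'int64', None]
print(shapes)  # [[2, 3], [4], None]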
