Skip to content

Commit bead0c2

Browse files
[Complex] Fix complex tensor print (PaddlePaddle#76380)
* fix complex tensor display
1 parent 7bc4bc4 commit bead0c2

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

python/paddle/tensor/to_string.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,20 @@ def _format_item(np_var, max_width=0, signed=False):
136136
item_str = f'{np_var:.0f}.'
137137
else:
138138
item_str = f'{np_var:.{DEFAULT_PRINT_OPTIONS.precision}f}'
139+
elif np_var.dtype == np.complex64 or np_var.dtype == np.complex128:
140+
re = np.real(np_var)
141+
im = np.imag(np_var)
142+
prec = DEFAULT_PRINT_OPTIONS.precision
143+
if DEFAULT_PRINT_OPTIONS.sci_mode:
144+
if im >= 0:
145+
item_str = f'({re:.{prec}e}+{im:.{prec}e}j)'
146+
else:
147+
item_str = f'({re:.{prec}e}{im:.{prec}e}j)'
148+
else:
149+
if im >= 0:
150+
item_str = f'({re:.{prec}f}+{im:.{prec}f}j)'
151+
else:
152+
item_str = f'({re:.{prec}f}{im:.{prec}f}j)'
139153
else:
140154
item_str = f'{np_var}'
141155

@@ -311,7 +325,9 @@ def _format_dense_tensor(tensor, indent):
311325
):
312326
np_tensor = mask_xpu_bf16_tensor(np_tensor)
313327

314-
summary = tensor.numel() > DEFAULT_PRINT_OPTIONS.threshold
328+
summary = (
329+
np.prod(tensor.shape, dtype="int64") > DEFAULT_PRINT_OPTIONS.threshold
330+
)
315331

316332
max_width, signed = _get_max_width(_to_summary(np_tensor))
317333

test/legacy_test/test_eager_tensor.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import paddle.nn.functional as F
2525
from paddle import base
2626
from paddle.base import core
27+
from paddle.tensor.to_string import DEFAULT_PRINT_OPTIONS
2728
from paddle.utils.dlpack import DLDeviceType
2829

2930

@@ -1337,6 +1338,72 @@ def test_tensor_str_fp8_e5m2(self):
13371338

13381339
self.assertEqual(a_str, expected)
13391340

1341+
def test_tensor_str_complex64(self):
1342+
original_opt = copy.deepcopy(DEFAULT_PRINT_OPTIONS)
1343+
try:
1344+
paddle.disable_static(paddle.CPUPlace())
1345+
a = paddle.to_tensor(
1346+
[[1.5 + 1j, 1.0 - 2j], [0 - 3j, 0]], dtype="complex64"
1347+
).cpu()
1348+
paddle.set_printoptions(precision=4)
1349+
a_str = str(a)
1350+
1351+
expected = """Tensor(shape=[2, 2], dtype=complex64, place=Place(cpu), stop_gradient=True,
1352+
[[(1.5000+1.0000j), (1.0000-2.0000j)],
1353+
[(0.0000-3.0000j), (0.0000+0.0000j)]])"""
1354+
1355+
self.assertEqual(a_str, expected)
1356+
1357+
paddle.set_printoptions(precision=4, sci_mode=True)
1358+
a_str = str(a)
1359+
1360+
expected = """Tensor(shape=[2, 2], dtype=complex64, place=Place(cpu), stop_gradient=True,
1361+
[[(1.5000e+00+1.0000e+00j), (1.0000e+00-2.0000e+00j)],
1362+
[(0.0000e+00-3.0000e+00j), (0.0000e+00+0.0000e+00j)]])"""
1363+
1364+
self.assertEqual(a_str, expected)
1365+
finally:
1366+
paddle.set_printoptions(
1367+
precision=original_opt.precision,
1368+
threshold=original_opt.threshold,
1369+
edgeitems=original_opt.edgeitems,
1370+
sci_mode=original_opt.sci_mode,
1371+
linewidth=original_opt.linewidth,
1372+
)
1373+
1374+
def test_tensor_str_complex128(self):
1375+
original_opt = copy.deepcopy(DEFAULT_PRINT_OPTIONS)
1376+
try:
1377+
paddle.disable_static(paddle.CPUPlace())
1378+
a = paddle.to_tensor(
1379+
[[1.5 + 1j, 1.0 - 2j], [0 - 3j, 0]], dtype="complex128"
1380+
).cpu()
1381+
paddle.set_printoptions(precision=4)
1382+
a_str = str(a)
1383+
1384+
expected = """Tensor(shape=[2, 2], dtype=complex128, place=Place(cpu), stop_gradient=True,
1385+
[[(1.5000+1.0000j), (1.0000-2.0000j)],
1386+
[(0.0000-3.0000j), (0.0000+0.0000j)]])"""
1387+
1388+
self.assertEqual(a_str, expected)
1389+
1390+
paddle.set_printoptions(precision=4, sci_mode=True)
1391+
a_str = str(a)
1392+
1393+
expected = """Tensor(shape=[2, 2], dtype=complex128, place=Place(cpu), stop_gradient=True,
1394+
[[(1.5000e+00+1.0000e+00j), (1.0000e+00-2.0000e+00j)],
1395+
[(0.0000e+00-3.0000e+00j), (0.0000e+00+0.0000e+00j)]])"""
1396+
1397+
self.assertEqual(a_str, expected)
1398+
finally:
1399+
paddle.set_printoptions(
1400+
precision=original_opt.precision,
1401+
threshold=original_opt.threshold,
1402+
edgeitems=original_opt.edgeitems,
1403+
sci_mode=original_opt.sci_mode,
1404+
linewidth=original_opt.linewidth,
1405+
)
1406+
13401407
def test_print_tensor_dtype(self):
13411408
paddle.disable_static(paddle.CPUPlace())
13421409
a = paddle.rand([1])

0 commit comments

Comments
 (0)