Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033)
* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042).
* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042)
* Updated `repr` to show the shape of the abbreviated arrays and show the shape and data type of zero-size arrays [#2067](https://github.com/IntelPython/dpctl/pull/2067)

### Fixed

Expand Down
33 changes: 23 additions & 10 deletions dpctl/tensor/_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@
}


def _move_to_next_line(string, s, line_width, prefix):
"""
Move string to next line if it doesn't fit in the current line.
"""
bottom_len = len(s) - (s.rfind("\n") + 1)
next_line = bottom_len + len(string) + 1 > line_width
string = ",\n" + " " * len(prefix) + string if next_line else ", " + string

return string


def _options_dict(
linewidth=None,
edgeitems=None,
Expand Down Expand Up @@ -463,16 +474,18 @@ def usm_ndarray_repr(
suffix=suffix,
)

if show_dtype:
dtype_str = "dtype={}".format(x.dtype.name)
bottom_len = len(s) - (s.rfind("\n") + 1)
next_line = bottom_len + len(dtype_str) + 1 > line_width
dtype_str = (
",\n" + " " * len(prefix) + dtype_str
if next_line
else ", " + dtype_str
)
if show_dtype or x.size == 0:
dtype_str = f"dtype={x.dtype.name}"
dtype_str = _move_to_next_line(dtype_str, s, line_width, prefix)
else:
dtype_str = ""

return prefix + s + dtype_str + suffix
options = get_print_options()
threshold = options["threshold"]
if x.size == 0 and x.shape != (0,) or x.size > threshold:
shape_str = f"shape={x.shape}"
shape_str = _move_to_next_line(shape_str, s, line_width, prefix)
else:
shape_str = ""

return prefix + s + shape_str + dtype_str + suffix
18 changes: 11 additions & 7 deletions dpctl/tests/test_usm_ndarray_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,7 @@ def test_print_repr(self):
)

x = dpt.arange(4, dtype="i4", sycl_queue=q)
x.sycl_queue.wait()
r = repr(x)
assert r == "usm_ndarray([0, 1, 2, 3], dtype=int32)"
assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"

dpt.set_print_options(linewidth=1)
np.testing.assert_equal(
Expand All @@ -296,30 +294,35 @@ def test_print_repr(self):
"\n dtype=int32)",
)

# zero-size array
dpt.set_print_options(linewidth=75)
x = dpt.ones((9, 0), dtype="i4", sycl_queue=q)
assert repr(x) == "usm_ndarray([], shape=(9, 0), dtype=int32)"

def test_print_repr_abbreviated(self):
q = get_queue_or_skip()

dpt.set_print_options(threshold=0, edgeitems=1)
x = dpt.arange(9, dtype="int64", sycl_queue=q)
assert repr(x) == "usm_ndarray([0, ..., 8])"
assert repr(x) == "usm_ndarray([0, ..., 8], shape=(9,))"

y = dpt.asarray(x, dtype="i4", copy=True)
assert repr(y) == "usm_ndarray([0, ..., 8], dtype=int32)"
assert repr(y) == "usm_ndarray([0, ..., 8], shape=(9,), dtype=int32)"

x = dpt.reshape(x, (3, 3))
np.testing.assert_equal(
repr(x),
"usm_ndarray([[0, ..., 2],"
"\n ...,"
"\n [6, ..., 8]])",
"\n [6, ..., 8]], shape=(3, 3))",
)

y = dpt.reshape(y, (3, 3))
np.testing.assert_equal(
repr(y),
"usm_ndarray([[0, ..., 2],"
"\n ...,"
"\n [6, ..., 8]], dtype=int32)",
"\n [6, ..., 8]], shape=(3, 3), dtype=int32)",
)

dpt.set_print_options(linewidth=1)
Expand All @@ -332,6 +335,7 @@ def test_print_repr_abbreviated(self):
"\n [6,"
"\n ...,"
"\n 8]],"
"\n shape=(3, 3),"
"\n dtype=int32)",
)

Expand Down
Loading