Skip to content

Commit 42a9dc6

Browse files
committed
Implemented dpctl.tensor.print_options
context manager
1 parent e02c6cb commit 42a9dc6

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
squeeze,
5959
stack,
6060
)
61-
from dpctl.tensor._print import get_print_options, set_print_options
61+
from dpctl.tensor._print import (
62+
get_print_options,
63+
print_options,
64+
set_print_options,
65+
)
6266
from dpctl.tensor._reshape import reshape
6367
from dpctl.tensor._usmarray import usm_ndarray
6468

@@ -132,4 +136,5 @@
132136
"meshgrid",
133137
"get_print_options",
134138
"set_print_options",
139+
"print_options",
135140
]

dpctl/tensor/_print.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import contextlib
1718
import operator
1819

1920
import numpy as np
@@ -125,6 +126,16 @@ def get_print_options():
125126
return _print_options.copy()
126127

127128

129+
@contextlib.contextmanager
130+
def print_options(*args, **kwargs):
131+
options = dpt.get_print_options()
132+
try:
133+
dpt.set_print_options(*args, **kwargs)
134+
yield dpt.get_print_options()
135+
finally:
136+
dpt.set_print_options(**options)
137+
138+
128139
def _nd_corners(x, edge_items, slices=()):
129140
axes_reduced = len(slices)
130141
if axes_reduced == x.ndim:
@@ -187,7 +198,9 @@ def usm_ndarray_str(
187198
else:
188199
data = dpt.asnumpy(x)
189200
with np.printoptions(**options):
190-
s = np.array2string(data, separator=separator, prefix=prefix, suffix=suffix)
201+
s = np.array2string(
202+
data, separator=separator, prefix=prefix, suffix=suffix
203+
)
191204
return s
192205

193206

0 commit comments

Comments
 (0)