Skip to content

Commit e02c6cb

Browse files
committed
usm_ndarray_repr inserts commas
user defined nanstr and infstr implemented
1 parent 8412394 commit e02c6cb

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

dpctl/tensor/_print.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,23 +183,12 @@ def usm_ndarray_str(
183183
if x.size > threshold:
184184
# need edge_items + 1 elements for np.array2string to abbreviate
185185
data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
186-
forward_threshold = 0
186+
options["threshold"] = 0
187187
else:
188188
data = dpt.asnumpy(x)
189-
forward_threshold = threshold
190-
return np.array2string(
191-
data,
192-
max_line_width=options["linewidth"],
193-
edgeitems=options["edgeitems"],
194-
threshold=forward_threshold,
195-
precision=options["precision"],
196-
floatmode=options["floatmode"],
197-
suppress_small=options["suppress"],
198-
sign=options["sign"],
199-
separator=separator,
200-
prefix=prefix,
201-
suffix=suffix,
202-
)
189+
with np.printoptions(**options):
190+
s = np.array2string(data, separator=separator, prefix=prefix, suffix=suffix)
191+
return s
203192

204193

205194
def usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
@@ -224,6 +213,7 @@ def usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
224213
line_width=line_width,
225214
precision=precision,
226215
suppress=suppress,
216+
separator=", ",
227217
prefix=prefix,
228218
suffix=suffix,
229219
)

0 commit comments

Comments
 (0)