Skip to content

Commit 587f099

Browse files
author
Vahid Tavanashad
committed
add support for ndmin for dpnp.array
1 parent 6ba840a commit 587f099

File tree

4 files changed

+217
-201
lines changed

4 files changed

+217
-201
lines changed

dpnp/dpnp_iface_arraycreation.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ def array(
319319
order : {"C", "F", "A", "K"}, optional
320320
Memory layout of the newly output array.
321321
Default: ``"K"``.
322+
ndmin : int, optional
323+
Specifies the minimum number of dimensions that the resulting array
324+
should have. Ones will be prepended to the shape as needed to meet
325+
this requirement.
326+
Default: ``0``.
322327
device : {None, string, SyclDevice, SyclQueue}, optional
323328
An array API concept of device where the output array is created.
324329
The `device` can be ``None`` (the default), an OneAPI filter selector
@@ -345,7 +350,6 @@ def array(
345350
Limitations
346351
-----------
347352
Parameter `subok` is supported only with default value ``False``.
348-
Parameter `ndmin` is supported only with default value ``0``.
349353
Parameter `like` is supported only with default value ``None``.
350354
Otherwise, the function raises ``NotImplementedError`` exception.
351355
@@ -399,13 +403,10 @@ def array(
399403
"""
400404

401405
dpnp.check_limitations(subok=subok, like=like)
402-
if ndmin != 0:
403-
raise NotImplementedError(
404-
"Keyword argument `ndmin` is supported only with "
405-
f"default value ``0``, but got {ndmin}"
406-
)
406+
if not isinstance(ndmin, (int, dpnp.integer)):
407+
raise TypeError(f"`ndmin` should be an integer, got {type(ndmin)}")
407408

408-
return dpnp_container.asarray(
409+
result = dpnp_container.asarray(
409410
a,
410411
dtype=dtype,
411412
copy=copy,
@@ -415,6 +416,13 @@ def array(
415416
sycl_queue=sycl_queue,
416417
)
417418

419+
res_ndim = result.ndim
420+
if res_ndim >= ndmin:
421+
return result
422+
423+
num_axes = ndmin - res_ndim
424+
return result[(dpnp.newaxis,) * num_axes + (slice(None),)]
425+
418426

419427
def asanyarray(
420428
a,

tests/test_arraycreation.py

Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_allclose,
1010
assert_array_equal,
1111
assert_equal,
12+
assert_raises,
1213
)
1314

1415
import dpnp
@@ -21,6 +22,40 @@
2122
)
2223

2324

25+
class TestArray:
26+
@pytest.mark.parametrize(
27+
"x", [numpy.ones((3, 4)), numpy.ones((0, 4)), [1, 2, 3], []]
28+
)
29+
@pytest.mark.parametrize("ndmin", [-5, -1, 0, 1, 2, 3, 4, 9, 21])
30+
def test_ndmin(self, x, ndmin):
31+
a = numpy.array(x, ndmin=ndmin)
32+
ia = dpnp.array(x, ndmin=ndmin)
33+
assert_array_equal(ia, a)
34+
35+
@pytest.mark.parametrize(
36+
"x",
37+
[
38+
numpy.ones((2, 3, 4, 5)),
39+
numpy.ones((3, 4)),
40+
numpy.ones((0, 4)),
41+
[1, 2, 3],
42+
[],
43+
],
44+
)
45+
@pytest.mark.parametrize("order", ["C", "F", "K", "A"])
46+
@pytest.mark.parametrize("ndmin", [1, 2, 3, 4, 9, 21])
47+
def test_ndmin_order(self, x, order, ndmin):
48+
a = numpy.array(x, order=order, ndmin=ndmin)
49+
ia = dpnp.array(x, order=order, ndmin=ndmin)
50+
assert a.flags.c_contiguous == ia.flags.c_contiguous
51+
assert a.flags.f_contiguous == ia.flags.f_contiguous
52+
assert_array_equal(ia, a)
53+
54+
def test_error(self):
55+
x = numpy.ones((3, 4))
56+
assert_raises(TypeError, dpnp.array, x, ndmin=3.0)
57+
58+
2459
class TestTrace:
2560
@pytest.mark.parametrize("a_sh", [(3, 4), (2, 2, 2)])
2661
@pytest.mark.parametrize(
@@ -140,17 +175,9 @@ def test_exception_subok(func, args):
140175
getattr(dpnp, func)(x, *args, subok=True)
141176

142177

143-
@pytest.mark.parametrize(
144-
"start", [0, -5, 10, -2.5, 9.7], ids=["0", "-5", "10", "-2.5", "9.7"]
145-
)
146-
@pytest.mark.parametrize(
147-
"stop",
148-
[None, 10, -2, 20.5, 1000],
149-
ids=["None", "10", "-2", "20.5", "10**5"],
150-
)
151-
@pytest.mark.parametrize(
152-
"step", [None, 1, 2.7, -1.6, 100], ids=["None", "1", "2.7", "-1.6", "100"]
153-
)
178+
@pytest.mark.parametrize("start", [0, -5, 10, -2.5, 9.7])
179+
@pytest.mark.parametrize("stop", [None, 10, -2, 20.5, 1000])
180+
@pytest.mark.parametrize("step", [None, 1, 2.7, -1.6, 100])
154181
@pytest.mark.parametrize(
155182
"dtype", get_all_dtypes(no_bool=True, no_float16=False)
156183
)
@@ -188,11 +215,7 @@ def test_arange(start, stop, step, dtype):
188215

189216

190217
@pytest.mark.parametrize("func", ["diag", "diagflat"])
191-
@pytest.mark.parametrize(
192-
"k",
193-
[-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6],
194-
ids=["-6", "-5", "-4", "-3", "-2", "-1", "0", "1", "2", "3", "4", "5", "6"],
195-
)
218+
@pytest.mark.parametrize("k", [-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
196219
@pytest.mark.parametrize(
197220
"v",
198221
[
@@ -251,17 +274,11 @@ def test_diag_diagflat_seq(func, seq):
251274
assert_array_equal(expected, result)
252275

253276

254-
@pytest.mark.parametrize("N", [0, 1, 2, 3], ids=["0", "1", "2", "3"])
255-
@pytest.mark.parametrize(
256-
"M", [None, 0, 1, 2, 3], ids=["None", "0", "1", "2", "3"]
257-
)
258-
@pytest.mark.parametrize(
259-
"k",
260-
[-4, -3, -2, -1, 0, 1, 2, 3, 4],
261-
ids=["-4", "-3", "-2", "-1", "0", "1", "2", "3", "4"],
262-
)
277+
@pytest.mark.parametrize("N", [0, 1, 2, 3])
278+
@pytest.mark.parametrize("M", [None, 0, 1, 2, 3])
279+
@pytest.mark.parametrize("k", [-4, -3, -2, -1, 0, 1, 2, 3, 4])
263280
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
264-
@pytest.mark.parametrize("order", [None, "C", "F"], ids=["None", "C", "F"])
281+
@pytest.mark.parametrize("order", [None, "C", "F"])
265282
def test_eye(N, M, k, dtype, order):
266283
func = lambda xp: xp.eye(N, M, k=k, dtype=dtype, order=order)
267284
assert_array_equal(func(numpy), func(dpnp))
@@ -317,7 +334,7 @@ def test_fromstring(dtype):
317334
assert_array_equal(func(dpnp), func(numpy))
318335

319336

320-
@pytest.mark.parametrize("n", [0, 1, 4], ids=["0", "1", "4"])
337+
@pytest.mark.parametrize("n", [0, 1, 4])
321338
@pytest.mark.parametrize("dtype", get_all_dtypes())
322339
def test_identity(n, dtype):
323340
func = lambda xp: xp.identity(n, dtype=dtype)
@@ -340,15 +357,9 @@ def test_loadtxt(dtype):
340357
assert_array_equal(dpnp_res, np_res)
341358

342359

343-
@pytest.mark.parametrize("N", [0, 1, 2, 3, 4], ids=["0", "1", "2", "3", "4"])
344-
@pytest.mark.parametrize(
345-
"M", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"]
346-
)
347-
@pytest.mark.parametrize(
348-
"k",
349-
[-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
350-
ids=["-5", "-4", "-3", "-2", "-1", "0", "1", "2", "3", "4", "5"],
351-
)
360+
@pytest.mark.parametrize("N", [0, 1, 2, 3, 4])
361+
@pytest.mark.parametrize("M", [None, 0, 1, 2, 3, 4])
362+
@pytest.mark.parametrize("k", [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
352363
@pytest.mark.parametrize("dtype", get_all_dtypes())
353364
def test_tri(N, M, k, dtype):
354365
func = lambda xp: xp.tri(N, M, k, dtype=dtype)
@@ -409,7 +420,6 @@ def test_tri_default_dtype():
409420
"[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]",
410421
],
411422
)
412-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
413423
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
414424
def test_tril(m, k, dtype):
415425
a = numpy.array(m, dtype=dtype)
@@ -463,7 +473,6 @@ def test_tril(m, k, dtype):
463473
"[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]",
464474
],
465475
)
466-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
467476
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
468477
def test_triu(m, k, dtype):
469478
a = numpy.array(m, dtype=dtype)
@@ -473,11 +482,7 @@ def test_triu(m, k, dtype):
473482
assert_array_equal(expected, result)
474483

475484

476-
@pytest.mark.parametrize(
477-
"k",
478-
[-4, -3, -2, -1, 0, 1, 2, 3, 4],
479-
ids=["-4", "-3", "-2", "-1", "0", "1", "2", "3", "4"],
480-
)
485+
@pytest.mark.parametrize("k", [-4, -3, -2, -1, 0, 1, 2, 3, 4])
481486
def test_triu_size_null(k):
482487
a = numpy.ones(shape=(1, 2, 0))
483488
ia = dpnp.array(a)
@@ -492,8 +497,8 @@ def test_triu_size_null(k):
492497
ids=["[1, 2, 3, 4]", "[]", "[0, 3, 5]"],
493498
)
494499
@pytest.mark.parametrize("dtype", get_all_dtypes())
495-
@pytest.mark.parametrize("n", [0, 1, 4, None], ids=["0", "1", "4", "None"])
496-
@pytest.mark.parametrize("increase", [True, False], ids=["True", "False"])
500+
@pytest.mark.parametrize("n", [0, 1, 4, None])
501+
@pytest.mark.parametrize("increase", [True, False])
497502
def test_vander(array, dtype, n, increase):
498503
if dtype in [dpnp.complex64, dpnp.complex128] and array == [0, 3, 5]:
499504
pytest.skip(
@@ -537,7 +542,7 @@ def test_vander_seq(sequence):
537542
"fill_value", [1.5, 2, 1.5 + 0.0j], ids=["1.5", "2", "1.5+0.j"]
538543
)
539544
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
540-
@pytest.mark.parametrize("order", [None, "C", "F"], ids=["None", "C", "F"])
545+
@pytest.mark.parametrize("order", [None, "C", "F"])
541546
def test_full(shape, fill_value, dtype, order):
542547
func = lambda xp: xp.full(shape, fill_value, dtype=dtype, order=order)
543548
assert_array_equal(func(numpy), func(dpnp))
@@ -562,8 +567,8 @@ def test_full_like(array, fill_value, dtype, order):
562567
assert_array_equal(func(numpy, a), func(dpnp, ia))
563568

564569

565-
@pytest.mark.parametrize("order1", ["F", "C"], ids=["F", "C"])
566-
@pytest.mark.parametrize("order2", ["F", "C"], ids=["F", "C"])
570+
@pytest.mark.parametrize("order1", ["F", "C"])
571+
@pytest.mark.parametrize("order2", ["F", "C"])
567572
def test_full_order(order1, order2):
568573
array = numpy.array([1, 2, 3], order=order1)
569574
a = numpy.full((3, 3), array, order=order2)
@@ -600,7 +605,7 @@ def test_full_invalid_fill_value(fill_value):
600605
ids=["()", "0", "(0,)", "(2, 0, 3)", "(3, 2)"],
601606
)
602607
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
603-
@pytest.mark.parametrize("order", [None, "C", "F"], ids=["None", "C", "F"])
608+
@pytest.mark.parametrize("order", [None, "C", "F"])
604609
def test_zeros(shape, dtype, order):
605610
func = lambda xp: xp.zeros(shape, dtype=dtype, order=order)
606611
assert_array_equal(func(numpy), func(dpnp))
@@ -627,7 +632,7 @@ def test_zeros_like(array, dtype, order):
627632
ids=["()", "0", "(0,)", "(2, 0, 3)", "(3, 2)"],
628633
)
629634
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
630-
@pytest.mark.parametrize("order", [None, "C", "F"], ids=["None", "C", "F"])
635+
@pytest.mark.parametrize("order", [None, "C", "F"])
631636
def test_empty(shape, dtype, order):
632637
func = lambda xp: xp.empty(shape, dtype=dtype, order=order)
633638
assert func(numpy).shape == func(dpnp).shape
@@ -654,7 +659,7 @@ def test_empty_like(array, dtype, order):
654659
ids=["()", "0", "(0,)", "(2, 0, 3)", "(3, 2)"],
655660
)
656661
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
657-
@pytest.mark.parametrize("order", [None, "C", "F"], ids=["None", "C", "F"])
662+
@pytest.mark.parametrize("order", [None, "C", "F"])
658663
def test_ones(shape, dtype, order):
659664
func = lambda xp: xp.ones(shape, dtype=dtype, order=order)
660665
assert_array_equal(func(numpy), func(dpnp))
@@ -695,12 +700,8 @@ def test_dpctl_tensor_input(func, args):
695700
assert_array_equal(X, Y)
696701

697702

698-
@pytest.mark.parametrize(
699-
"start", [0, -5, 10, -2.5, 9.7], ids=["0", "-5", "10", "-2.5", "9.7"]
700-
)
701-
@pytest.mark.parametrize(
702-
"stop", [0, 10, -2, 20.5, 1000], ids=["0", "10", "-2", "20.5", "1000"]
703-
)
703+
@pytest.mark.parametrize("start", [0, -5, 10, -2.5, 9.7])
704+
@pytest.mark.parametrize("stop", [0, 10, -2, 20.5, 1000])
704705
@pytest.mark.parametrize(
705706
"num",
706707
[1, 5, numpy.array(10), dpnp.array(17), dpt.asarray(100)],
@@ -709,7 +710,7 @@ def test_dpctl_tensor_input(func, args):
709710
@pytest.mark.parametrize(
710711
"dtype", get_all_dtypes(no_bool=True, no_float16=False)
711712
)
712-
@pytest.mark.parametrize("retstep", [True, False], ids=["True", "False"])
713+
@pytest.mark.parametrize("retstep", [True, False])
713714
def test_linspace(start, stop, num, dtype, retstep):
714715
res_np = numpy.linspace(start, stop, num, dtype=dtype, retstep=retstep)
715716
res_dp = dpnp.linspace(start, stop, num, dtype=dtype, retstep=retstep)
@@ -803,7 +804,7 @@ def test_linspace_retstep(start, stop):
803804
ids=["[]", "[[1]]", "[[1, 2, 3], [4, 5, 6]]", "[[1, 2], [3, 4], [5, 6]]"],
804805
)
805806
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
806-
@pytest.mark.parametrize("indexing", ["ij", "xy"], ids=["ij", "xy"])
807+
@pytest.mark.parametrize("indexing", ["ij", "xy"])
807808
def test_meshgrid(arrays, dtype, indexing):
808809
func = lambda xp, xi: xp.meshgrid(*xi, indexing=indexing)
809810
a = tuple(numpy.array(array, dtype=dtype) for array in arrays)

0 commit comments

Comments
 (0)