Skip to content

Commit 1dfe86c

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent 6c981e4 commit 1dfe86c

File tree

2 files changed

+91
-223
lines changed

2 files changed

+91
-223
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 47 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def assert_sycl_queue_equal(result, expected):
7878
],
7979
)
8080
@pytest.mark.parametrize("device", valid_dev + [None], ids=dev_ids + [None])
81-
def test_array_creation(func, arg, kwargs, device):
81+
def test_array_creation_from_scratch(func, arg, kwargs, device):
8282
kwargs = dict(kwargs)
8383
kwargs["device"] = device
8484
x = getattr(dpnp, func)(*arg, **kwargs)
@@ -90,115 +90,59 @@ def test_array_creation(func, arg, kwargs, device):
9090

9191

9292
@pytest.mark.parametrize(
93-
"func, args, kwargs",
93+
"func, args",
9494
[
95-
pytest.param("copy", ["x0"], {}),
96-
pytest.param("diag", ["x0"], {}),
97-
pytest.param("empty_like", ["x0"], {}),
98-
pytest.param("full_like", ["x0"], {"fill_value": 5}),
99-
pytest.param("geomspace", ["x0[0:3]", "8", "4"], {}),
100-
pytest.param("geomspace", ["1", "x0[2:4]", "4"], {}),
101-
pytest.param("linspace", ["x0[0:2]", "8", "4"], {}),
102-
pytest.param("linspace", ["0", "x0[2:4]", "4"], {}),
103-
pytest.param("logspace", ["x0[0:2]", "8", "4"], {}),
104-
pytest.param("logspace", ["0", "x0[2:4]", "4"], {}),
105-
pytest.param("ones_like", ["x0"], {}),
106-
pytest.param("tril", ["x0.reshape((2,2))"], {}),
107-
pytest.param("triu", ["x0.reshape((2,2))"], {}),
108-
pytest.param("linspace", ["x0", "4", "4"], {}),
109-
pytest.param("linspace", ["1", "x0", "4"], {}),
110-
pytest.param("vander", ["x0"], {}),
111-
pytest.param("zeros_like", ["x0"], {}),
112-
],
113-
)
114-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
115-
def test_array_creation_follow_device(func, args, kwargs, device):
116-
x = dpnp.array([1, 2, 3, 4], device=device)
117-
dpnp_args = [eval(val, {"x0": x}) for val in args]
118-
y = getattr(dpnp, func)(*dpnp_args, **kwargs)
119-
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
120-
121-
122-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
123-
def test_array_creation_follow_device_logspace_base(device):
124-
x = dpnp.array([1, 2, 3, 4], device=device)
125-
y = dpnp.logspace(0, 8, 4, base=x[1:3])
126-
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
127-
128-
129-
@pytest.mark.parametrize(
130-
"func, args, kwargs",
131-
[
132-
pytest.param("diag", ["x0"], {}),
133-
pytest.param("diagflat", ["x0"], {}),
134-
],
135-
)
136-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
137-
def test_array_creation_follow_device_2d_array(func, args, kwargs, device):
138-
x = dpnp.arange(9, device=device).reshape(3, 3)
139-
dpnp_args = [eval(val, {"x0": x}) for val in args]
140-
y = getattr(dpnp, func)(*dpnp_args, **kwargs)
141-
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
142-
143-
144-
@pytest.mark.parametrize(
145-
"func, args, kwargs",
146-
[
147-
pytest.param("copy", ["x0"], {}),
148-
pytest.param("diag", ["x0"], {}),
149-
pytest.param("empty_like", ["x0"], {}),
150-
pytest.param("full", ["10", "x0[3]"], {}),
151-
pytest.param("full_like", ["x0"], {"fill_value": 5}),
152-
pytest.param("ones_like", ["x0"], {}),
153-
pytest.param("zeros_like", ["x0"], {}),
154-
pytest.param("linspace", ["x0", "4", "4"], {}),
155-
pytest.param("linspace", ["1", "x0", "4"], {}),
156-
pytest.param("vander", ["x0"], {}),
95+
pytest.param("copy", ["x0"]),
96+
pytest.param("diag", ["x0"]),
97+
pytest.param("diag", ["x0.reshape((2,2))"]),
98+
pytest.param("diagflat", ["x0.reshape((2,2))"]),
99+
pytest.param("empty_like", ["x0"]),
100+
pytest.param("full", ["10", "x0[3]"]),
101+
pytest.param("full_like", ["x0", "5"]),
102+
pytest.param("geomspace", ["x0[0:3]", "8", "4"]),
103+
pytest.param("geomspace", ["1", "x0[2:4]", "4"]),
104+
pytest.param("linspace", ["x0[0:2]", "8", "4"]),
105+
pytest.param("linspace", ["0", "x0[2:4]", "4"]),
106+
pytest.param("logspace", ["x0[0:2]", "8", "4"]),
107+
pytest.param("logspace", ["0", "x0[2:4]", "4"]),
108+
pytest.param("ones_like", ["x0"]),
109+
pytest.param("vander", ["x0"]),
110+
pytest.param("zeros_like", ["x0"]),
157111
],
158112
)
159113
@pytest.mark.parametrize("device_x", valid_dev, ids=dev_ids)
160114
@pytest.mark.parametrize("device_y", valid_dev, ids=dev_ids)
161-
def test_array_creation_cross_device(func, args, kwargs, device_x, device_y):
115+
def test_array_creation_from_array(func, args, device_x, device_y):
162116
if func == "linspace" and is_win_platform():
163117
pytest.skip("CPU driver experiences an instability on Windows.")
164118

165119
x = dpnp.array([1, 2, 3, 4], device=device_x)
166-
dpnp_args = [eval(val, {"x0": x}) for val in args]
120+
args = [eval(val, {"x0": x}) for val in args]
167121

168-
dpnp_kwargs = dict(kwargs)
169-
y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
122+
# follow device
123+
y = getattr(dpnp, func)(*args)
170124
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
171125

172-
dpnp_kwargs["device"] = device_y
173-
y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
174-
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
126+
# cross device
127+
# TODO: include geomspace when issue dpnp#2352 is resolved
128+
if func != "geomspace":
129+
y = getattr(dpnp, func)(*args, device=device_y)
130+
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
175131

176132

177-
@pytest.mark.parametrize(
178-
"func, args, kwargs",
179-
[
180-
pytest.param("diag", ["x0"], {}),
181-
pytest.param("diagflat", ["x0"], {}),
182-
],
183-
)
184133
@pytest.mark.parametrize("device_x", valid_dev, ids=dev_ids)
185134
@pytest.mark.parametrize("device_y", valid_dev, ids=dev_ids)
186-
def test_array_creation_cross_device_2d_array(
187-
func, args, kwargs, device_x, device_y
188-
):
189-
if func == "linspace" and is_win_platform():
190-
pytest.skip("CPU driver experiences an instability on Windows.")
191-
192-
x = dpnp.arange(9, device=device_x).reshape(3, 3)
193-
dpnp_args = [eval(val, {"x0": x}) for val in args]
135+
def test_array_creation_logspace_base(device_x, device_y):
136+
x = dpnp.array([1, 2, 3, 4], device=device_x)
194137

195-
dpnp_kwargs = dict(kwargs)
196-
y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
138+
# follow device
139+
y = dpnp.logspace(0, 8, 4, base=x[1:3])
197140
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
198141

199-
dpnp_kwargs["device"] = device_y
200-
y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
201-
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
142+
# TODO: include geomspace when issue dpnp#2353 is resolved
143+
# cross device
144+
# y = dpnp.logspace(0, 8, 4, base=x[1:3], device=device_y)
145+
# assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
202146

203147

204148
@pytest.mark.parametrize("device", valid_dev + [None], ids=dev_ids + [None])
@@ -378,9 +322,9 @@ def test_meshgrid(device):
378322
"tan", [-dpnp.pi / 2, -dpnp.pi / 4, 0.0, dpnp.pi / 4, dpnp.pi / 2]
379323
),
380324
pytest.param("tanh", [-5.0, -3.5, 0.0, 3.5, 5.0]),
381-
pytest.param(
382-
"trace", [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
383-
),
325+
pytest.param("trace", numpy.eye(3)),
326+
pytest.param("tril", numpy.ones((3, 3))),
327+
pytest.param("triu", numpy.ones((3, 3))),
384328
pytest.param("trapezoid", [1, 2, 3]),
385329
pytest.param("trim_zeros", [0, 0, 0, 1, 2, 3, 0, 2, 1, 0]),
386330
pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
@@ -437,9 +381,7 @@ def test_1in_1out(func, data, device):
437381
pytest.param("dot", [3 + 2j, 4 + 1j, 5], [1, 2 + 3j, 3]),
438382
pytest.param("extract", [False, True, True, False], [0, 1, 2, 3]),
439383
pytest.param(
440-
"float_power",
441-
[0, 1, 2, 3, 4, 5],
442-
[1.0, 2.0, 3.0, 3.0, 2.0, 1.0],
384+
"float_power", [0, 1, 2, 3, 4, 5], [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]
443385
),
444386
pytest.param(
445387
"floor_divide", [1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]
@@ -451,39 +393,23 @@ def test_1in_1out(func, data, device):
451393
[-3.0, -2.0, -1.0, 1.0, 2.0, 3.0],
452394
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
453395
),
454-
pytest.param(
455-
"gcd",
456-
[0, 1, 2, 3, 4, 5],
457-
[20, 20, 20, 20, 20, 20],
458-
),
396+
pytest.param("gcd", [0, 1, 2, 3, 4, 5], [20, 20, 20, 20, 20, 20]),
459397
pytest.param(
460398
"gradient",
461399
[1.0, 2.0, 4.0, 7.0, 11.0, 16.0],
462400
[0.0, 1.0, 1.5, 3.5, 4.0, 6.0],
463401
),
464402
pytest.param("heaviside", [-1.5, 0, 2.0], [0.5]),
465403
pytest.param(
466-
"histogram_bin_edges",
467-
[0, 0, 0, 1, 2, 3, 3, 4, 5],
468-
[1, 2],
469-
),
470-
pytest.param(
471-
"hypot", [[1.0, 2.0, 3.0, 4.0]], [[-1.0, -2.0, -4.0, -5.0]]
404+
"histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5], [1, 2]
472405
),
406+
pytest.param("hypot", [1.0, 2.0, 3.0, 4.0], [-1.0, -2.0, -4.0, -5.0]),
473407
pytest.param("inner", [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]),
474408
pytest.param("kron", [3.0, 4.0, 5.0], [1.0, 2.0]),
475-
pytest.param(
476-
"lcm",
477-
[0, 1, 2, 3, 4, 5],
478-
[20, 20, 20, 20, 20, 20],
479-
),
480-
pytest.param(
481-
"ldexp",
482-
[5, 5, 5, 5, 5],
483-
[0, 1, 2, 3, 4],
484-
),
485-
pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]),
486-
pytest.param("logaddexp2", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]),
409+
pytest.param("lcm", [0, 1, 2, 3, 4, 5], [20, 20, 20, 20, 20, 20]),
410+
pytest.param("ldexp", [5, 5, 5, 5, 5], [0, 1, 2, 3, 4]),
411+
pytest.param("logaddexp", [-1, 2, 5, 9], [4, -3, 2, -8]),
412+
pytest.param("logaddexp2", [-1, 2, 5, 9], [4, -3, 2, -8]),
487413
pytest.param(
488414
"matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]]
489415
),
@@ -613,6 +539,7 @@ def test_logic_op_2in(op, device):
613539
"func, data, scalar",
614540
[
615541
pytest.param("searchsorted", [11, 12, 13, 14, 15], 13),
542+
pytest.param("broadcast_to", numpy.ones(7), (2, 7)),
616543
],
617544
)
618545
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
@@ -1351,7 +1278,6 @@ def test_array_copy(device, func, device_param, queue_param):
13511278
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
13521279
def test_array_creation_from_dpctl(copy, device):
13531280
dpt_data = dpt.ones((3, 3), device=device)
1354-
13551281
result = dpnp.array(dpt_data, copy=copy)
13561282

13571283
assert_sycl_queue_equal(result.sycl_queue, dpt_data.sycl_queue)
@@ -1387,13 +1313,6 @@ def test_from_dlpack_with_dpt(arr_dtype, device):
13871313
assert_sycl_queue_equal(X.sycl_queue, Y.sycl_queue)
13881314

13891315

1390-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1391-
def test_broadcast_to(device):
1392-
x = dpnp.arange(5, device=device)
1393-
y = dpnp.broadcast_to(x, (3, 5))
1394-
assert_sycl_queue_equal(x.sycl_queue, y.sycl_queue)
1395-
1396-
13971316
@pytest.mark.parametrize(
13981317
"func,data1,data2",
13991318
[

0 commit comments

Comments
 (0)