Skip to content

Commit b653dfa

Browse files
author
Diptorup Deb
committed
Add unit tests
1 parent 1a72510 commit b653dfa

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpctl
6+
import dpnp
7+
8+
import numba_dpex.experimental as exp_dpex
9+
from numba_dpex import NdRange, Range, dpjit
10+
11+
12+
@exp_dpex.kernel(
13+
release_gil=False,
14+
no_compile=True,
15+
no_cpython_wrapper=True,
16+
no_cfunc_wrapper=True,
17+
)
18+
def add(a, b, c):
19+
c[0] = b[0] + a[0]
20+
21+
22+
@exp_dpex.kernel(
23+
release_gil=False,
24+
no_compile=True,
25+
no_cpython_wrapper=True,
26+
no_cfunc_wrapper=True,
27+
)
28+
def sq(a, b):
29+
a[0] = b[0] * b[0]
30+
31+
32+
def test_call_kernel_from_cpython():
33+
"""
34+
Tests if we can call a kernel function from CPython using the call_kernel
35+
dpjit function.
36+
"""
37+
38+
q = dpctl.SyclQueue()
39+
a = dpnp.ones(100, sycl_queue=q)
40+
b = dpnp.ones_like(a, sycl_queue=q)
41+
c = dpnp.zeros_like(a, sycl_queue=q)
42+
r = Range(100)
43+
ndr = NdRange(global_size=(100,), local_size=(1,))
44+
45+
exp_dpex.call_kernel(add, r, a, b, c)
46+
47+
assert c[0] == b[0] + a[0]
48+
49+
exp_dpex.call_kernel(add, ndr, a, b, c)
50+
51+
assert c[0] == b[0] + a[0]
52+
53+
54+
def test_call_kernel_from_dpjit():
55+
"""
56+
Tests if we can call a kernel function from a dpjit function using the
57+
call_kernel dpjit function.
58+
"""
59+
60+
@dpjit
61+
def range_kernel_caller(q, a, b, c):
62+
r = Range(100)
63+
exp_dpex.call_kernel(add, r, a, b, c)
64+
return c
65+
66+
@dpjit
67+
def ndrange_kernel_caller(q, a, b, c):
68+
gr = Range(100)
69+
lr = Range(1)
70+
ndr = NdRange(gr, lr)
71+
exp_dpex.call_kernel(add, ndr, a, b, c)
72+
return c
73+
74+
q = dpctl.SyclQueue()
75+
a = dpnp.ones(100, sycl_queue=q)
76+
b = dpnp.ones_like(a, sycl_queue=q)
77+
c = dpnp.zeros_like(a, sycl_queue=q)
78+
79+
range_kernel_caller(q, a, b, c)
80+
81+
assert c[0] == b[0] + a[0]
82+
83+
ndrange_kernel_caller(q, a, b, c)
84+
85+
assert c[0] == b[0] + a[0]
86+
87+
88+
def test_call_multiple_kernels():
89+
"""
90+
Tests if the call_kernel dpjit function supports calling different types of
91+
kernel with different number of arguments.
92+
"""
93+
q = dpctl.SyclQueue()
94+
a = dpnp.ones(100, sycl_queue=q)
95+
b = dpnp.ones_like(a, sycl_queue=q)
96+
c = dpnp.zeros_like(a, sycl_queue=q)
97+
r = Range(100)
98+
99+
exp_dpex.call_kernel(add, r, a, b, c)
100+
101+
assert c[0] == b[0] + a[0]
102+
103+
exp_dpex.call_kernel(sq, r, a, c)
104+
105+
assert a[0] == c[0] * c[0]

0 commit comments

Comments
 (0)