Skip to content

Commit 6120680

Browse files
Authored by Diptorup Deb
Merge pull request #326 from IntelPython/feature/switch_to_new_dpex_kernel
Switch to experimental dpex kernel
2 parents d8431f4 + 54d5aab commit 6120680

File tree

9 files changed: 146 additions (+146) and 76 deletions (-76).

dpbench/benchmarks/black_scholes/black_scholes_numba_dpex_k.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,19 @@
44

55
from math import erf, exp, log, sqrt
66

7-
import numba_dpex as dpex
7+
import numba_dpex.experimental as dpex
8+
from numba_dpex import kernel_api as kapi
89

910

1011
@dpex.kernel
11-
def _black_scholes_kernel(nopt, price, strike, t, rate, volatility, call, put):
12+
def _black_scholes_kernel(
13+
item: kapi.Item, nopt, price, strike, t, rate, volatility, call, put
14+
):
1215
dtype = price.dtype
1316
mr = -rate
1417
sig_sig_two = volatility * volatility * dtype.type(2)
1518

16-
i = dpex.get_global_id(0)
19+
i = item.get_id(0)
1720

1821
P = price[i]
1922
S = strike[i]
@@ -40,6 +43,15 @@ def _black_scholes_kernel(nopt, price, strike, t, rate, volatility, call, put):
4043

4144

4245
def black_scholes(nopt, price, strike, t, rate, volatility, call, put):
43-
_black_scholes_kernel[dpex.Range(nopt)](
44-
nopt, price, strike, t, rate, volatility, call, put
46+
dpex.call_kernel(
47+
_black_scholes_kernel,
48+
kapi.Range(nopt),
49+
nopt,
50+
price,
51+
strike,
52+
t,
53+
rate,
54+
volatility,
55+
call,
56+
put,
4557
)

dpbench/benchmarks/dbscan/dbscan_numba_dpex_k.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,9 @@
44

55
import dpnp as np
66
import numba as nb
7-
import numba_dpex as dpex
7+
import numba_dpex.experimental as dpex
88
import numpy
9+
from numba_dpex import kernel_api as kapi
910

1011
NOISE = -1
1112
UNDEFINED = -2
@@ -50,8 +51,10 @@ def _queue_empty(head, tail):
5051

5152

5253
@dpex.kernel
53-
def get_neighborhood(n, dim, data, eps, ind_lst, sz_lst, block_size, nblocks):
54-
i = dpex.get_global_id(0)
54+
def get_neighborhood(
55+
item: kapi.Item, n, dim, data, eps, ind_lst, sz_lst, block_size, nblocks
56+
):
57+
i = item.get_id(0)
5558

5659
start = i * block_size
5760
stop = n if i + 1 == nblocks else start + block_size
@@ -130,7 +133,9 @@ def dbscan(n_samples, n_features, data, eps, min_pts):
130133
)
131134
sizes = np.zeros_like(data, shape=n_samples, dtype=np.int64)
132135

133-
get_neighborhood[dpex.Range(n_samples)](
136+
dpex.call_kernel(
137+
get_neighborhood,
138+
kapi.Range(n_samples),
134139
n_samples,
135140
n_features,
136141
data,

dpbench/benchmarks/gpairs/gpairs_numba_dpex_k.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,15 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import numba_dpex as dpex
6+
import numba_dpex.experimental as dpexexp
7+
from numba_dpex import kernel_api as kapi
68

79
# This implementation is numba dpex kernel version with atomics.
810

911

10-
@dpex.kernel
12+
@dpexexp.kernel
1113
def count_weighted_pairs_3d_intel_no_slm_ker(
14+
nd_item: kapi.NdItem,
1215
n,
1316
nbins,
1417
slm_hist_size,
@@ -25,14 +28,14 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
2528
result,
2629
):
2730
dtype = x0.dtype
28-
lid0 = dpex.get_local_id(0)
29-
gr0 = dpex.get_group_id(0)
31+
lid0 = nd_item.get_local_id(0)
32+
gr0 = nd_item.get_group().get_group_id(0)
3033

31-
lid1 = dpex.get_local_id(1)
32-
gr1 = dpex.get_group_id(1)
34+
lid1 = nd_item.get_local_id(1)
35+
gr1 = nd_item.get_group().get_group_id(1)
3336

34-
lws0 = dpex.get_local_size(0)
35-
lws1 = dpex.get_local_size(1)
37+
lws0 = nd_item.get_local_range(0)
38+
lws1 = nd_item.get_local_range(1)
3639

3740
n_wi = 20
3841

@@ -107,7 +110,8 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
107110

108111
pk = k
109112
for p in range(private_hist_size):
110-
dpex.atomic.add(result, pk, private_hist[p])
113+
result_aref = kapi.AtomicRef(result, index=pk)
114+
result_aref.fetch_add(private_hist[p])
111115
pk += 1
112116

113117

@@ -147,7 +151,9 @@ def gpairs(
147151
ceiling_quotient(nbins, private_hist_size) * private_hist_size
148152
)
149153

150-
count_weighted_pairs_3d_intel_no_slm_ker[dpex.NdRange(gwsRange, lwsRange)](
154+
dpexexp.call_kernel(
155+
count_weighted_pairs_3d_intel_no_slm_ker,
156+
kapi.NdRange(dpex.Range(*gwsRange), dpex.Range(*lwsRange)),
151157
nopt,
152158
nbins,
153159
slm_hist_size,

dpbench/benchmarks/kmeans/kmeans_numba_dpex_k.py

Lines changed: 73 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@
66
from math import sqrt
77

88
import numba_dpex as dpex
9-
import numpy
9+
import numba_dpex.experimental as dpexexp
1010
from dpctl import tensor as dpt
11-
from numba_dpex import NdRange
11+
from numba_dpex import kernel_api as kapi
1212

1313

1414
def DivUp(numerator, denominator):
@@ -25,9 +25,15 @@ def getGroupByCluster( # noqa: C901
2525
):
2626
local_copies = min(4, max(1, DivUp(local_size_, num_centroids)))
2727

28-
@dpex.kernel
28+
@dpexexp.kernel
2929
def groupByCluster(
30-
arrayP, arrayPcluster, arrayC, NewCentroids, NewCount, last
30+
nd_item: kapi.NdItem,
31+
arrayP,
32+
arrayPcluster,
33+
arrayC,
34+
NewCentroids,
35+
NewCount,
36+
last,
3137
):
3238
numpoints = arrayP.shape[0]
3339
localCentroids = dpex.local.array((dims, num_centroids), dtype=dtyp)
@@ -38,9 +44,9 @@ def groupByCluster(
3844
(local_copies, num_centroids), dtype=dpt.int32
3945
)
4046

41-
grid = dpex.get_group_id(0)
42-
lid = dpex.get_local_id(0)
43-
local_size = dpex.get_local_size(0)
47+
grid = nd_item.get_group().get_group_id(0)
48+
lid = nd_item.get_local_id(0)
49+
local_size = nd_item.get_local_range(0)
4450

4551
for i in range(lid, num_centroids * dims, local_size):
4652
localCentroids[i % dims, i // dims] = arrayC[i // dims, i % dims]
@@ -51,7 +57,7 @@ def groupByCluster(
5157
for lc in range(local_copies):
5258
localNewCount[lc, c] = 0
5359

54-
dpex.barrier(dpex.LOCAL_MEM_FENCE)
60+
kapi.group_barrier(nd_item.get_group())
5561

5662
for i in range(WorkPI):
5763
point_id = grid * WorkPI * local_size + i * local_size + lid
@@ -73,44 +79,59 @@ def groupByCluster(
7379

7480
lc = lid % local_copies
7581
for d in range(dims):
76-
dpex.atomic.add(
77-
localNewCentroids, (lc, d, nearest_centroid), localP[d]
82+
localNewCentroids_aref = kapi.AtomicRef(
83+
localNewCentroids,
84+
index=(lc, d, nearest_centroid),
85+
address_space=kapi.AddressSpace.LOCAL,
7886
)
87+
localNewCentroids_aref.fetch_add(localP[d])
7988

80-
dpex.atomic.add(localNewCount, (lc, nearest_centroid), 1)
89+
localNewCount_aref = kapi.AtomicRef(
90+
localNewCount,
91+
index=(lc, nearest_centroid),
92+
address_space=kapi.AddressSpace.LOCAL,
93+
)
94+
localNewCount_aref.fetch_add(1)
8195

8296
if last:
8397
arrayPcluster[point_id] = nearest_centroid
8498

85-
dpex.barrier(dpex.LOCAL_MEM_FENCE)
99+
kapi.group_barrier(nd_item.get_group())
86100

87101
for i in range(lid, num_centroids * dims, local_size):
88102
local_centroid_d = dtyp.type(0)
89103
for lc in range(local_copies):
90104
local_centroid_d += localNewCentroids[lc, i % dims, i // dims]
91105

92-
dpex.atomic.add(
93-
NewCentroids,
94-
(i // dims, i % dims),
95-
local_centroid_d,
106+
NewCentroids_aref = kapi.AtomicRef(
107+
NewCentroids, index=(i // dims, i % dims)
96108
)
109+
NewCentroids_aref.fetch_add(local_centroid_d)
97110

98111
for c in range(lid, num_centroids, local_size):
99112
local_centroid_npoints = dpt.int32.type(0)
100113
for lc in range(local_copies):
101114
local_centroid_npoints += localNewCount[lc, c]
102115

103-
dpex.atomic.add(NewCount, c, local_centroid_npoints)
116+
NewCount_aref = kapi.AtomicRef(NewCount, index=c)
117+
NewCount_aref.fetch_add(local_centroid_npoints)
104118

105119
return groupByCluster
106120

107121

108122
@lru_cache(maxsize=1)
109123
def getUpdateCentroids(dims, num_centroids, dtyp, local_size_):
110-
@dpex.kernel
111-
def updateCentroids(diff, arrayC, arrayCnumpoint, NewCentroids, NewCount):
112-
lid = dpex.get_local_id(0)
113-
local_size = dpex.get_local_size(0)
124+
@dpexexp.kernel
125+
def updateCentroids(
126+
nd_item: kapi.NdItem,
127+
diff,
128+
arrayC,
129+
arrayCnumpoint,
130+
NewCentroids,
131+
NewCount,
132+
):
133+
lid = nd_item.get_local_id(0)
134+
local_size = nd_item.get_local_range(0)
114135

115136
local_distance = dpex.local.array(local_size_, dtype=dtyp)
116137

@@ -134,7 +155,7 @@ def updateCentroids(diff, arrayC, arrayCnumpoint, NewCentroids, NewCount):
134155
max_distance = max(max_distance, distance)
135156
local_distance[c] = max_distance
136157

137-
dpex.barrier(dpex.LOCAL_MEM_FENCE)
158+
kapi.group_barrier(nd_item.get_group())
138159

139160
if lid == 0:
140161
for c in range(local_size):
@@ -147,19 +168,19 @@ def updateCentroids(diff, arrayC, arrayCnumpoint, NewCentroids, NewCount):
147168

148169
@lru_cache(maxsize=1)
149170
def getUpdateLabels(dims, num_centroids, dtyp, WorkPI):
150-
@dpex.kernel
151-
def updateLabels(arrayP, arrayPcluster, arrayC):
171+
@dpexexp.kernel
172+
def updateLabels(nd_item: kapi.NdItem, arrayP, arrayPcluster, arrayC):
152173
numpoints = arrayP.shape[0]
153174
localCentroids = dpex.local.array((dims, num_centroids), dtype=dtyp)
154175

155-
grid = dpex.get_group_id(0)
156-
lid = dpex.get_local_id(0)
157-
local_size = dpex.get_local_size(0)
176+
grid = nd_item.get_group().get_group_id(0)
177+
lid = nd_item.get_local_id(0)
178+
local_size = nd_item.get_local_range(0)
158179

159180
for i in range(lid, num_centroids * dims, local_size):
160181
localCentroids[i % dims, i // dims] = arrayC[i // dims, i % dims]
161182

162-
dpex.barrier(dpex.LOCAL_MEM_FENCE)
183+
kapi.group_barrier(nd_item.get_group())
163184

164185
for i in range(WorkPI):
165186
point_id = grid * WorkPI * local_size + i * local_size + lid
@@ -224,19 +245,36 @@ def kmeans_kernel(
224245
for i in range(niters):
225246
last = i == (niters - 1)
226247
if diff_host < tolerance:
227-
updateLabels[NdRange((global_size,), (local_size,))](
228-
arrayP, arrayPcluster, arrayC
248+
dpexexp.call_kernel(
249+
updateLabels,
250+
kapi.NdRange((global_size,), (local_size,)),
251+
arrayP,
252+
arrayPcluster,
253+
arrayC,
229254
)
230255
break
231256

232-
groupByCluster[NdRange((global_size,), (local_size,))](
233-
arrayP, arrayPcluster, arrayC, NewCentroids, NewCount, last
257+
dpexexp.call_kernel(
258+
groupByCluster,
259+
kapi.NdRange((global_size,), (local_size,)),
260+
arrayP,
261+
arrayPcluster,
262+
arrayC,
263+
NewCentroids,
264+
NewCount,
265+
last,
234266
)
235267

236268
update_centroid_size = min(num_centroids, local_size)
237-
updateCentroids[
238-
NdRange((update_centroid_size,), (update_centroid_size,))
239-
](diff, arrayC, arrayCnumpoint, NewCentroids, NewCount)
269+
dpexexp.call_kernel(
270+
updateCentroids,
271+
kapi.NdRange((update_centroid_size,), (update_centroid_size,)),
272+
diff,
273+
arrayC,
274+
arrayCnumpoint,
275+
NewCentroids,
276+
NewCount,
277+
)
240278
diff_host = dpt.asnumpy(diff)[0]
241279

242280

dpbench/benchmarks/knn/knn_numba_dpex_k.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,14 @@
55
from math import sqrt
66

77
import numba_dpex as dpex
8+
import numba_dpex.experimental as dpexexp
89
import numpy as np
10+
from numba_dpex import kernel_api as kapi
911

1012

11-
@dpex.kernel
13+
@dpexexp.kernel
1214
def _knn_kernel( # noqa: C901: TODO: can we simplify logic?
15+
item: kapi.Item,
1316
train,
1417
train_labels,
1518
test,
@@ -21,7 +24,7 @@ def _knn_kernel( # noqa: C901: TODO: can we simplify logic?
2124
data_dim,
2225
):
2326
dtype = train.dtype
24-
i = dpex.get_global_id(0)
27+
i = item.get_id(0)
2528
# here k has to be 5 in order to match with numpy
2629
queue_neighbors = dpex.private.array(shape=(5, 2), dtype=dtype)
2730

@@ -106,7 +109,9 @@ def knn(
106109
votes_to_classes,
107110
data_dim,
108111
):
109-
_knn_kernel[dpex.Range(test_size)](
112+
dpexexp.call_kernel(
113+
_knn_kernel,
114+
kapi.Range(test_size),
110115
x_train,
111116
y_train,
112117
x_test,

dpbench/benchmarks/l2_norm/l2_norm_numba_dpex_k.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,13 @@
44

55
import math
66

7-
import numba_dpex as dpex
7+
import numba_dpex.experimental as dpex
8+
from numba_dpex import kernel_api as kapi
89

910

1011
@dpex.kernel
11-
def l2_norm_kernel(a, d):
12-
i = dpex.get_global_id(0)
12+
def l2_norm_kernel(item: kapi.Item, a, d):
13+
i = item.get_id(0)
1314
a_rows = a.shape[1]
1415
d[i] = 0.0
1516
for k in range(a_rows):
@@ -18,4 +19,4 @@ def l2_norm_kernel(a, d):
1819

1920

2021
def l2_norm(a, d):
21-
l2_norm_kernel[dpex.Range(a.shape[0])](a, d)
22+
dpex.call_kernel(l2_norm_kernel, kapi.Range(a.shape[0]), a, d)

0 commit comments

Comments
 (0)