66from math import sqrt
77
88import numba_dpex as dpex
9- import numpy
9+ import numba_dpex . experimental as dpexexp
1010from dpctl import tensor as dpt
11- from numba_dpex import NdRange
11+ from numba_dpex import kernel_api as kapi
1212
1313
1414def DivUp (numerator , denominator ):
@@ -25,9 +25,15 @@ def getGroupByCluster( # noqa: C901
2525):
2626 local_copies = min (4 , max (1 , DivUp (local_size_ , num_centroids )))
2727
28- @dpex .kernel
28+ @dpexexp .kernel
2929 def groupByCluster (
30- arrayP , arrayPcluster , arrayC , NewCentroids , NewCount , last
30+ nd_item : kapi .NdItem ,
31+ arrayP ,
32+ arrayPcluster ,
33+ arrayC ,
34+ NewCentroids ,
35+ NewCount ,
36+ last ,
3137 ):
3238 numpoints = arrayP .shape [0 ]
3339 localCentroids = dpex .local .array ((dims , num_centroids ), dtype = dtyp )
@@ -38,9 +44,9 @@ def groupByCluster(
3844 (local_copies , num_centroids ), dtype = dpt .int32
3945 )
4046
41- grid = dpex .get_group_id (0 )
42- lid = dpex .get_local_id (0 )
43- local_size = dpex . get_local_size (0 )
47+ grid = nd_item . get_group () .get_group_id (0 )
48+ lid = nd_item .get_local_id (0 )
49+ local_size = nd_item . get_local_range (0 )
4450
4551 for i in range (lid , num_centroids * dims , local_size ):
4652 localCentroids [i % dims , i // dims ] = arrayC [i // dims , i % dims ]
@@ -51,7 +57,7 @@ def groupByCluster(
5157 for lc in range (local_copies ):
5258 localNewCount [lc , c ] = 0
5359
54- dpex . barrier ( dpex . LOCAL_MEM_FENCE )
60+ kapi . group_barrier ( nd_item . get_group () )
5561
5662 for i in range (WorkPI ):
5763 point_id = grid * WorkPI * local_size + i * local_size + lid
@@ -73,44 +79,59 @@ def groupByCluster(
7379
7480 lc = lid % local_copies
7581 for d in range (dims ):
76- dpex .atomic .add (
77- localNewCentroids , (lc , d , nearest_centroid ), localP [d ]
82+ localNewCentroids_aref = kapi .AtomicRef (
83+ localNewCentroids ,
84+ index = (lc , d , nearest_centroid ),
85+ address_space = kapi .AddressSpace .LOCAL ,
7886 )
87+ localNewCentroids_aref .fetch_add (localP [d ])
7988
80- dpex .atomic .add (localNewCount , (lc , nearest_centroid ), 1 )
89+ localNewCount_aref = kapi .AtomicRef (
90+ localNewCount ,
91+ index = (lc , nearest_centroid ),
92+ address_space = kapi .AddressSpace .LOCAL ,
93+ )
94+ localNewCount_aref .fetch_add (1 )
8195
8296 if last :
8397 arrayPcluster [point_id ] = nearest_centroid
8498
85- dpex . barrier ( dpex . LOCAL_MEM_FENCE )
99+ kapi . group_barrier ( nd_item . get_group () )
86100
87101 for i in range (lid , num_centroids * dims , local_size ):
88102 local_centroid_d = dtyp .type (0 )
89103 for lc in range (local_copies ):
90104 local_centroid_d += localNewCentroids [lc , i % dims , i // dims ]
91105
92- dpex .atomic .add (
93- NewCentroids ,
94- (i // dims , i % dims ),
95- local_centroid_d ,
106+ NewCentroids_aref = kapi .AtomicRef (
107+ NewCentroids , index = (i // dims , i % dims )
96108 )
109+ NewCentroids_aref .fetch_add (local_centroid_d )
97110
98111 for c in range (lid , num_centroids , local_size ):
99112 local_centroid_npoints = dpt .int32 .type (0 )
100113 for lc in range (local_copies ):
101114 local_centroid_npoints += localNewCount [lc , c ]
102115
103- dpex .atomic .add (NewCount , c , local_centroid_npoints )
116+ NewCount_aref = kapi .AtomicRef (NewCount , index = c )
117+ NewCount_aref .fetch_add (local_centroid_npoints )
104118
105119 return groupByCluster
106120
107121
108122@lru_cache (maxsize = 1 )
109123def getUpdateCentroids (dims , num_centroids , dtyp , local_size_ ):
110- @dpex .kernel
111- def updateCentroids (diff , arrayC , arrayCnumpoint , NewCentroids , NewCount ):
112- lid = dpex .get_local_id (0 )
113- local_size = dpex .get_local_size (0 )
124+ @dpexexp .kernel
125+ def updateCentroids (
126+ nd_item : kapi .NdItem ,
127+ diff ,
128+ arrayC ,
129+ arrayCnumpoint ,
130+ NewCentroids ,
131+ NewCount ,
132+ ):
133+ lid = nd_item .get_local_id (0 )
134+ local_size = nd_item .get_local_range (0 )
114135
115136 local_distance = dpex .local .array (local_size_ , dtype = dtyp )
116137
@@ -134,7 +155,7 @@ def updateCentroids(diff, arrayC, arrayCnumpoint, NewCentroids, NewCount):
134155 max_distance = max (max_distance , distance )
135156 local_distance [c ] = max_distance
136157
137- dpex . barrier ( dpex . LOCAL_MEM_FENCE )
158+ kapi . group_barrier ( nd_item . get_group () )
138159
139160 if lid == 0 :
140161 for c in range (local_size ):
@@ -147,19 +168,19 @@ def updateCentroids(diff, arrayC, arrayCnumpoint, NewCentroids, NewCount):
147168
148169@lru_cache (maxsize = 1 )
149170def getUpdateLabels (dims , num_centroids , dtyp , WorkPI ):
150- @dpex .kernel
151- def updateLabels (arrayP , arrayPcluster , arrayC ):
171+ @dpexexp .kernel
172+ def updateLabels (nd_item : kapi . NdItem , arrayP , arrayPcluster , arrayC ):
152173 numpoints = arrayP .shape [0 ]
153174 localCentroids = dpex .local .array ((dims , num_centroids ), dtype = dtyp )
154175
155- grid = dpex .get_group_id (0 )
156- lid = dpex .get_local_id (0 )
157- local_size = dpex . get_local_size (0 )
176+ grid = nd_item . get_group () .get_group_id (0 )
177+ lid = nd_item .get_local_id (0 )
178+ local_size = nd_item . get_local_range (0 )
158179
159180 for i in range (lid , num_centroids * dims , local_size ):
160181 localCentroids [i % dims , i // dims ] = arrayC [i // dims , i % dims ]
161182
162- dpex . barrier ( dpex . LOCAL_MEM_FENCE )
183+ kapi . group_barrier ( nd_item . get_group () )
163184
164185 for i in range (WorkPI ):
165186 point_id = grid * WorkPI * local_size + i * local_size + lid
@@ -224,19 +245,36 @@ def kmeans_kernel(
224245 for i in range (niters ):
225246 last = i == (niters - 1 )
226247 if diff_host < tolerance :
227- updateLabels [NdRange ((global_size ,), (local_size ,))](
228- arrayP , arrayPcluster , arrayC
248+ dpexexp .call_kernel (
249+ updateLabels ,
250+ kapi .NdRange ((global_size ,), (local_size ,)),
251+ arrayP ,
252+ arrayPcluster ,
253+ arrayC ,
229254 )
230255 break
231256
232- groupByCluster [NdRange ((global_size ,), (local_size ,))](
233- arrayP , arrayPcluster , arrayC , NewCentroids , NewCount , last
257+ dpexexp .call_kernel (
258+ groupByCluster ,
259+ kapi .NdRange ((global_size ,), (local_size ,)),
260+ arrayP ,
261+ arrayPcluster ,
262+ arrayC ,
263+ NewCentroids ,
264+ NewCount ,
265+ last ,
234266 )
235267
236268 update_centroid_size = min (num_centroids , local_size )
237- updateCentroids [
238- NdRange ((update_centroid_size ,), (update_centroid_size ,))
239- ](diff , arrayC , arrayCnumpoint , NewCentroids , NewCount )
269+ dpexexp .call_kernel (
270+ updateCentroids ,
271+ kapi .NdRange ((update_centroid_size ,), (update_centroid_size ,)),
272+ diff ,
273+ arrayC ,
274+ arrayCnumpoint ,
275+ NewCentroids ,
276+ NewCount ,
277+ )
240278 diff_host = dpt .asnumpy (diff )[0 ]
241279
242280
0 commit comments