66from math import sqrt
77
88import numba_dpex as dpex
9- import numba_dpex .experimental as dpexexp
109from dpctl import tensor as dpt
1110from numba_dpex import kernel_api as kapi
1211
@@ -23,9 +22,7 @@ def Align(value, base):
2322def getGroupByCluster ( # noqa: C901
2423 dims , num_centroids , dtyp , WorkPI , local_size_
2524):
26- local_copies = min (4 , max (1 , DivUp (local_size_ , num_centroids )))
27-
28- @dpexexp .kernel
25+ @dpex .kernel
2926 def groupByCluster (
3027 nd_item : kapi .NdItem ,
3128 arrayP ,
@@ -34,15 +31,12 @@ def groupByCluster(
3431 NewCentroids ,
3532 NewCount ,
3633 last ,
34+ local_copies ,
35+ localCentroids ,
36+ localNewCentroids ,
37+ localNewCount ,
3738 ):
3839 numpoints = arrayP .shape [0 ]
39- localCentroids = dpex .local .array ((dims , num_centroids ), dtype = dtyp )
40- localNewCentroids = dpex .local .array (
41- (local_copies , dims , num_centroids ), dtype = dtyp
42- )
43- localNewCount = dpex .local .array (
44- (local_copies , num_centroids ), dtype = dpt .int32
45- )
4640
4741 grid = nd_item .get_group ().get_group_id (0 )
4842 lid = nd_item .get_local_id (0 )
@@ -121,20 +115,19 @@ def groupByCluster(
121115
122116@lru_cache (maxsize = 1 )
123117def getUpdateCentroids (dims , num_centroids , dtyp , local_size_ ):
124- @dpexexp .kernel
118+ @dpex .kernel
125119 def updateCentroids (
126120 nd_item : kapi .NdItem ,
127121 diff ,
128122 arrayC ,
129123 arrayCnumpoint ,
130124 NewCentroids ,
131125 NewCount ,
126+ local_distance ,
132127 ):
133128 lid = nd_item .get_local_id (0 )
134129 local_size = nd_item .get_local_range (0 )
135130
136- local_distance = dpex .local .array (local_size_ , dtype = dtyp )
137-
138131 max_distance = dtyp .type (0 )
139132 for c in range (lid , num_centroids , local_size ):
140133 numpoints = NewCount [c ]
@@ -168,10 +161,11 @@ def updateCentroids(
168161
169162@lru_cache (maxsize = 1 )
170163def getUpdateLabels (dims , num_centroids , dtyp , WorkPI ):
171- @dpexexp .kernel
172- def updateLabels (nd_item : kapi .NdItem , arrayP , arrayPcluster , arrayC ):
164+ @dpex .kernel
165+ def updateLabels (
166+ nd_item : kapi .NdItem , arrayP , arrayPcluster , arrayC , localCentroids
167+ ):
173168 numpoints = arrayP .shape [0 ]
174- localCentroids = dpex .local .array ((dims , num_centroids ), dtype = dtyp )
175169
176170 grid = nd_item .get_group ().get_group_id (0 )
177171 lid = nd_item .get_local_id (0 )
@@ -245,16 +239,31 @@ def kmeans_kernel(
245239 for i in range (niters ):
246240 last = i == (niters - 1 )
247241 if diff_host < tolerance :
248- dpexexp .call_kernel (
242+ localCentroids = kapi .LocalAccessor (
243+ (dims , num_centroids ), dtype = arrayP .dtype
244+ )
245+
246+ dpex .call_kernel (
249247 updateLabels ,
250248 kapi .NdRange ((global_size ,), (local_size ,)),
251249 arrayP ,
252250 arrayPcluster ,
253251 arrayC ,
252+ localCentroids ,
254253 )
255254 break
256255
257- dpexexp .call_kernel (
256+ local_copies = min (4 , max (1 , DivUp (local_size , num_centroids )))
257+ localCentroids = kapi .LocalAccessor (
258+ (dims , num_centroids ), dtype = arrayP .dtype
259+ )
260+ localNewCentroids = kapi .LocalAccessor (
261+ (local_copies , dims , num_centroids ), dtype = arrayP .dtype
262+ )
263+ localNewCount = kapi .LocalAccessor (
264+ (local_copies , num_centroids ), dtype = dpt .int64
265+ )
266+ dpex .call_kernel (
258267 groupByCluster ,
259268 kapi .NdRange ((global_size ,), (local_size ,)),
260269 arrayP ,
@@ -263,17 +272,23 @@ def kmeans_kernel(
263272 NewCentroids ,
264273 NewCount ,
265274 last ,
275+ local_copies ,
276+ localCentroids ,
277+ localNewCentroids ,
278+ localNewCount ,
266279 )
267280
281+ local_distance = kapi .LocalAccessor (local_size , dtype = arrayP .dtype )
268282 update_centroid_size = min (num_centroids , local_size )
269- dpexexp .call_kernel (
283+ dpex .call_kernel (
270284 updateCentroids ,
271285 kapi .NdRange ((update_centroid_size ,), (update_centroid_size ,)),
272286 diff ,
273287 arrayC ,
274288 arrayCnumpoint ,
275289 NewCentroids ,
276290 NewCount ,
291+ local_distance ,
277292 )
278293 diff_host = dpt .asnumpy (diff )[0 ]
279294
0 commit comments