1616from numba_dpex import spirv_kernel_target
1717from numba_dpex .core import config
1818from numba_dpex .core .types import Array
19+ from numba_dpex .kernel_api import AddressSpace as address_space
1920from numba_dpex .kernel_api_impl .spirv .codegen import SPIR_DATA_LAYOUT
2021from numba_dpex .ocl .atomics import atomic_helper
21- from numba_dpex .utils import address_space
2222
2323from . import stubs
2424from ._declare_function import _declare_function
@@ -301,7 +301,7 @@ def dpex_private_array_integer(context, builder, sig, args):
301301 shape = (length ,),
302302 dtype = dtype ,
303303 symbol_name = "_dpex_pmem" ,
304- addrspace = address_space .PRIVATE ,
304+ addrspace = address_space .PRIVATE . value ,
305305 )
306306
307307
@@ -316,7 +316,7 @@ def dpex_private_array_tuple(context, builder, sig, args):
316316 shape = shape ,
317317 dtype = dtype ,
318318 symbol_name = "_dpex_pmem" ,
319- addrspace = address_space .PRIVATE ,
319+ addrspace = address_space .PRIVATE . value ,
320320 )
321321
322322
@@ -330,7 +330,7 @@ def dpex_local_array_integer(context, builder, sig, args):
330330 shape = (length ,),
331331 dtype = dtype ,
332332 symbol_name = "_dpex_lmem" ,
333- addrspace = address_space .LOCAL ,
333+ addrspace = address_space .LOCAL . value ,
334334 )
335335
336336
@@ -345,7 +345,7 @@ def dpex_local_array_tuple(context, builder, sig, args):
345345 shape = shape ,
346346 dtype = dtype ,
347347 symbol_name = "_dpex_lmem" ,
348- addrspace = address_space .LOCAL ,
348+ addrspace = address_space .LOCAL . value ,
349349 )
350350
351351
@@ -358,7 +358,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
358358 lldtype = context .get_data_type (dtype )
359359 laryty = llvmir .ArrayType (lldtype , elemcount )
360360
361- if addrspace == address_space .LOCAL :
361+ if addrspace == address_space .LOCAL . value :
362362 lmod = builder .module
363363
364364 # Create global variable in the requested address-space
@@ -374,7 +374,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
374374 if dtype not in types .number_domain :
375375 raise TypeError ("unsupported type: %s" % dtype )
376376
377- elif addrspace == address_space .PRIVATE :
377+ elif addrspace == address_space .PRIVATE . value :
378378 gvmem = cgutils .alloca_once (builder , laryty , name = symbol_name )
379379 else :
380380 raise NotImplementedError ("addrspace {addrspace}" .format (** locals ()))
@@ -397,7 +397,7 @@ def _make_array(
397397 dtype ,
398398 shape ,
399399 layout = "C" ,
400- addrspace = address_space .GENERIC ,
400+ addrspace = address_space .GENERIC . value ,
401401):
402402 ndim = len (shape )
403403 # Create array object
0 commit comments