Skip to content

Commit 77397fe

Browse files
author
Diptorup Deb
committed
Remove utils.constants and utils.array_utils
1 parent ebc6c32 commit 77397fe

File tree

14 files changed

+34
-374
lines changed

14 files changed

+34
-374
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
DpctlMDLocalAccessorType,
1818
LocalAccessorType,
1919
)
20-
from numba_dpex.utils import address_space
20+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
2121

2222
from ..types import (
2323
Array,
@@ -62,7 +62,7 @@ def __init__(self, dmm, fe_type):
6262
adrsp = (
6363
fe_type.addrspace
6464
if fe_type.addrspace is not None
65-
else address_space.GLOBAL
65+
else address_space.GLOBAL.value
6666
)
6767
be_type = dmm.lookup(fe_type.dtype).get_data_type().as_pointer(adrsp)
6868
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)

numba_dpex/core/types/kernel_api/local_accessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from numba.np import numpy_support
99

1010
from numba_dpex.core.types import USMNdArray
11-
from numba_dpex.utils import address_space as AddressSpace
11+
from numba_dpex.kernel_api.memory_enums import AddressSpace
1212

1313

1414
class DpctlMDLocalAccessorType(Type):
@@ -36,14 +36,14 @@ def __init__(self, ndim, dtype):
3636

3737
type_name = (
3838
f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, "
39-
f"address_space={AddressSpace.LOCAL})"
39+
f"address_space={AddressSpace.LOCAL.value})"
4040
)
4141

4242
super().__init__(
4343
ndim=ndim,
4444
layout="C",
4545
dtype=parsed_dtype,
46-
addrspace=AddressSpace.LOCAL,
46+
addrspace=AddressSpace.LOCAL.value,
4747
name=type_name,
4848
)
4949

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from numba.np.numpy_support import from_dtype
1414

1515
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
16-
from numba_dpex.utils.constants import address_space
16+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
1717

1818

1919
class USMNdArray(Array):
@@ -30,7 +30,7 @@ def __init__(
3030
readonly=False,
3131
name=None,
3232
aligned=True,
33-
addrspace=address_space.GLOBAL,
33+
addrspace=address_space.GLOBAL.value,
3434
):
3535
if (
3636
queue is not None

numba_dpex/core/typing/typeof.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from numba.np import numpy_support
1010

1111
from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem
12+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
1213
from numba_dpex.kernel_api.ranges import NdRange, Range
13-
from numba_dpex.utils.constants import address_space
1414

1515
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
1616
from ..types.dpnp_ndarray_type import DpnpNdArray
@@ -59,7 +59,7 @@ def _array_typeof_helper(val, array_class_type):
5959
readonly=readonly,
6060
usm_type=usm_type,
6161
queue=ty_queue,
62-
addrspace=address_space.GLOBAL,
62+
addrspace=address_space.GLOBAL.value,
6363
)
6464

6565

numba_dpex/kernel_api_impl/spirv/overloads/_private_array_overloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020

2121
from numba_dpex.core.types import USMNdArray
2222
from numba_dpex.kernel_api import PrivateArray
23+
from numba_dpex.kernel_api.memory_enums import AddressSpace
2324
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
2425
np_cfarray,
2526
require_literal,
2627
)
2728
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTypingContext
28-
from numba_dpex.utils import address_space as AddressSpace
2929

3030
from ._registry import lower
3131

@@ -50,7 +50,7 @@ def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)):
5050
dtype=_ty_parse_dtype(dtype),
5151
ndim=_ty_parse_shape(shape),
5252
layout="C",
53-
addrspace=AddressSpace.PRIVATE,
53+
addrspace=AddressSpace.PRIVATE.value,
5454
)
5555

5656
return typer

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from numba_dpex.core.types import IntEnumLiteral
2525
from numba_dpex.core.typing import dpnpdecl
2626
from numba_dpex.kernel_api.flag_enum import FlagEnum
27+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
2728
from numba_dpex.kernel_api_impl.spirv.arrayobj import populate_array
2829
from numba_dpex.ocl.mathimpl import lower_ocl_impl, sig_mapper
29-
from numba_dpex.utils import address_space, calling_conv
3030

3131
from . import codegen
3232
from .overloads._registry import registry as spirv_registry
@@ -367,7 +367,7 @@ def declare_function(self, module, fndesc):
367367
if not self.enable_debuginfo:
368368
fn.attributes.add("alwaysinline")
369369
ret = super().declare_function(module, fndesc)
370-
ret.calling_convention = calling_conv.CC_SPIR_FUNC
370+
ret.calling_convention = CC_SPIR_FUNC
371371
return ret
372372

373373
def insert_const_string(self, mod, string):
@@ -392,15 +392,15 @@ def insert_const_string(self, mod, string):
392392
if gv is None:
393393
# Not defined yet
394394
gv = cgutils.add_global_variable(
395-
mod, text.type, name=name, addrspace=address_space.GENERIC
395+
mod, text.type, name=name, addrspace=address_space.GENERIC.value
396396
)
397397
gv.linkage = "internal"
398398
gv.global_constant = True
399399
gv.initializer = text
400400

401401
# Cast to a i8* pointer
402402
charty = gv.type.pointee.element
403-
return gv.bitcast(charty.as_pointer(address_space.GENERIC))
403+
return gv.bitcast(charty.as_pointer(address_space.GENERIC.value))
404404

405405
def addrspacecast(self, builder, src, addrspace):
406406
"""Insert an LLVM addressspace cast instruction into the module.

numba_dpex/ocl/ocldecl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import numba_dpex as dpex
1717
from numba_dpex.core.types import Array
18-
from numba_dpex.utils import address_space
18+
from numba_dpex.kernel_api import AddressSpace as address_space
1919

2020
registry = Registry()
2121
intrinsic = registry.register
@@ -158,7 +158,7 @@ def typer(shape, dtype):
158158
dtype=nb_dtype,
159159
ndim=ndim,
160160
layout="C",
161-
addrspace=address_space.LOCAL,
161+
addrspace=address_space.LOCAL.value,
162162
)
163163

164164
return typer
@@ -201,7 +201,7 @@ def typer(shape, dtype):
201201
dtype=nb_dtype,
202202
ndim=ndim,
203203
layout="C",
204-
addrspace=address_space.PRIVATE,
204+
addrspace=address_space.PRIVATE.value,
205205
)
206206

207207
return typer

numba_dpex/ocl/oclimpl.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from numba_dpex import spirv_kernel_target
1717
from numba_dpex.core import config
1818
from numba_dpex.core.types import Array
19+
from numba_dpex.kernel_api import AddressSpace as address_space
1920
from numba_dpex.kernel_api_impl.spirv.codegen import SPIR_DATA_LAYOUT
2021
from numba_dpex.ocl.atomics import atomic_helper
21-
from numba_dpex.utils import address_space
2222

2323
from . import stubs
2424
from ._declare_function import _declare_function
@@ -301,7 +301,7 @@ def dpex_private_array_integer(context, builder, sig, args):
301301
shape=(length,),
302302
dtype=dtype,
303303
symbol_name="_dpex_pmem",
304-
addrspace=address_space.PRIVATE,
304+
addrspace=address_space.PRIVATE.value,
305305
)
306306

307307

@@ -316,7 +316,7 @@ def dpex_private_array_tuple(context, builder, sig, args):
316316
shape=shape,
317317
dtype=dtype,
318318
symbol_name="_dpex_pmem",
319-
addrspace=address_space.PRIVATE,
319+
addrspace=address_space.PRIVATE.value,
320320
)
321321

322322

@@ -330,7 +330,7 @@ def dpex_local_array_integer(context, builder, sig, args):
330330
shape=(length,),
331331
dtype=dtype,
332332
symbol_name="_dpex_lmem",
333-
addrspace=address_space.LOCAL,
333+
addrspace=address_space.LOCAL.value,
334334
)
335335

336336

@@ -345,7 +345,7 @@ def dpex_local_array_tuple(context, builder, sig, args):
345345
shape=shape,
346346
dtype=dtype,
347347
symbol_name="_dpex_lmem",
348-
addrspace=address_space.LOCAL,
348+
addrspace=address_space.LOCAL.value,
349349
)
350350

351351

@@ -358,7 +358,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
358358
lldtype = context.get_data_type(dtype)
359359
laryty = llvmir.ArrayType(lldtype, elemcount)
360360

361-
if addrspace == address_space.LOCAL:
361+
if addrspace == address_space.LOCAL.value:
362362
lmod = builder.module
363363

364364
# Create global variable in the requested address-space
@@ -374,7 +374,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
374374
if dtype not in types.number_domain:
375375
raise TypeError("unsupported type: %s" % dtype)
376376

377-
elif addrspace == address_space.PRIVATE:
377+
elif addrspace == address_space.PRIVATE.value:
378378
gvmem = cgutils.alloca_once(builder, laryty, name=symbol_name)
379379
else:
380380
raise NotImplementedError("addrspace {addrspace}".format(**locals()))
@@ -397,7 +397,7 @@ def _make_array(
397397
dtype,
398398
shape,
399399
layout="C",
400-
addrspace=address_space.GENERIC,
400+
addrspace=address_space.GENERIC.value,
401401
):
402402
ndim = len(shape)
403403
# Create array object

numba_dpex/printimpl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from numba.core import cgutils, types
99
from numba.core.imputils import Registry
1010

11-
from numba_dpex.utils import address_space
11+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
1212

1313
registry = Registry()
1414
lower = registry.lower
1515

1616

1717
def declare_print(lmod):
1818
voidptrty = llvmir.PointerType(
19-
llvmir.IntType(8), addrspace=address_space.GENERIC
19+
llvmir.IntType(8), addrspace=address_space.GENERIC.value
2020
)
2121
printfty = llvmir.FunctionType(
2222
llvmir.IntType(32), [voidptrty], var_arg=True

numba_dpex/tests/core/test_itanium_mangler_extension.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numba.core import types
88

99
import numba_dpex.core.utils.itanium_mangler as itanium_mangler
10-
from numba_dpex.utils import address_space
10+
from numba_dpex.kernel_api import AddressSpace as address_space
1111

1212
list_of_dtypes = [
1313
(int32, "i"),
@@ -25,10 +25,10 @@ def dtypes(request):
2525

2626

2727
list_of_addrspaces = [
28-
(address_space.PRIVATE, "3AS0"),
29-
(address_space.GLOBAL, "3AS1"),
30-
(address_space.LOCAL, "3AS3"),
31-
(address_space.GENERIC, "3AS4"),
28+
(address_space.PRIVATE.value, "3AS0"),
29+
(address_space.GLOBAL.value, "3AS1"),
30+
(address_space.LOCAL.value, "3AS3"),
31+
(address_space.GENERIC.value, "3AS4"),
3232
]
3333

3434

0 commit comments

Comments
 (0)