Skip to content

Commit 5fb6093

Browse files
authored
Merge pull request #1423 from IntelPython/refactor/keep-utils-in-one-place
Move all utility functions into the core.utils module
2 parents ebc6c32 + 0ad3dba commit 5fb6093

21 files changed

+55
-415
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
DpctlMDLocalAccessorType,
1818
LocalAccessorType,
1919
)
20-
from numba_dpex.utils import address_space
20+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
2121

2222
from ..types import (
2323
Array,
@@ -62,7 +62,7 @@ def __init__(self, dmm, fe_type):
6262
adrsp = (
6363
fe_type.addrspace
6464
if fe_type.addrspace is not None
65-
else address_space.GLOBAL
65+
else address_space.GLOBAL.value
6666
)
6767
be_type = dmm.lookup(fe_type.dtype).get_data_type().as_pointer(adrsp)
6868
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)

numba_dpex/core/kernel_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
NdItemType,
2727
)
2828
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
29-
from numba_dpex.core.utils import kernel_launcher as kl
29+
from numba_dpex.core.utils import call_kernel_builder as kl
3030
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
3131
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference
3232
from numba_dpex.kernel_api_impl.spirv.dispatcher import (

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ReductionHelper,
2121
ReductionKernelVariables,
2222
)
23-
from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder
23+
from numba_dpex.core.utils.call_kernel_builder import KernelLaunchIRBuilder
2424
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2525

2626
from ..exceptions import UnsupportedParforError

numba_dpex/core/parfors/reduction_helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from numba.parfors import parfor
1818
from numba.parfors.parfor_lowering_utils import ParforLoweringBuilder
1919

20-
from numba_dpex import utils
2120
from numba_dpex.core.datamodel.models import (
2221
dpex_data_model_manager as kernel_dmm,
2322
)
24-
from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder
23+
from numba_dpex.core.utils.call_kernel_builder import KernelLaunchIRBuilder
24+
from numba_dpex.core.utils.cgutils_extra import get_llvm_type
2525
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2626

2727
from ..types.dpnp_ndarray_type import DpnpNdArray
@@ -431,11 +431,11 @@ def copy_final_sum_to_host(self, parfor_kernel):
431431

432432
dest = builder.bitcast(
433433
lowerer.getvar(redvar),
434-
utils.get_llvm_type(context=context, type=types.voidptr),
434+
get_llvm_type(context=context, type=types.voidptr),
435435
)
436436
src = builder.bitcast(
437437
builder.load(array_attr),
438-
utils.get_llvm_type(context=context, type=types.voidptr),
438+
get_llvm_type(context=context, type=types.voidptr),
439439
)
440440

441441
args = [

numba_dpex/core/types/kernel_api/local_accessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from numba.np import numpy_support
99

1010
from numba_dpex.core.types import USMNdArray
11-
from numba_dpex.utils import address_space as AddressSpace
11+
from numba_dpex.kernel_api.memory_enums import AddressSpace
1212

1313

1414
class DpctlMDLocalAccessorType(Type):
@@ -36,14 +36,14 @@ def __init__(self, ndim, dtype):
3636

3737
type_name = (
3838
f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, "
39-
f"address_space={AddressSpace.LOCAL})"
39+
f"address_space={AddressSpace.LOCAL.value})"
4040
)
4141

4242
super().__init__(
4343
ndim=ndim,
4444
layout="C",
4545
dtype=parsed_dtype,
46-
addrspace=AddressSpace.LOCAL,
46+
addrspace=AddressSpace.LOCAL.value,
4747
name=type_name,
4848
)
4949

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from numba.np.numpy_support import from_dtype
1414

1515
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
16-
from numba_dpex.utils.constants import address_space
16+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
1717

1818

1919
class USMNdArray(Array):
@@ -30,7 +30,7 @@ def __init__(
3030
readonly=False,
3131
name=None,
3232
aligned=True,
33-
addrspace=address_space.GLOBAL,
33+
addrspace=address_space.GLOBAL.value,
3434
):
3535
if (
3636
queue is not None

numba_dpex/core/typing/typeof.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from numba.np import numpy_support
1010

1111
from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem
12+
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
1213
from numba_dpex.kernel_api.ranges import NdRange, Range
13-
from numba_dpex.utils.constants import address_space
1414

1515
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
1616
from ..types.dpnp_ndarray_type import DpnpNdArray
@@ -59,7 +59,7 @@ def _array_typeof_helper(val, array_class_type):
5959
readonly=readonly,
6060
usm_type=usm_type,
6161
queue=ty_queue,
62-
addrspace=address_space.GLOBAL,
62+
addrspace=address_space.GLOBAL.value,
6363
)
6464

6565

numba_dpex/core/utils/kernel_launcher.py renamed to numba_dpex/core/utils/call_kernel_builder.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@
1717
from numba.core.datamodel import DataModelManager
1818
from numba.core.types.containers import UniTuple
1919

20-
from numba_dpex import utils
2120
from numba_dpex.core import config
2221
from numba_dpex.core.exceptions import UnreachableError
2322
from numba_dpex.core.runtime.context import DpexRTContext
2423
from numba_dpex.core.types import USMNdArray
2524
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
2625
from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType
26+
from numba_dpex.core.utils import cgutils_extra
2727
from numba_dpex.core.utils.kernel_flattened_args_builder import (
2828
KernelFlattenedArgsBuilder,
2929
)
3030
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
31-
from numba_dpex.utils import create_null_ptr
3231

3332
MAX_SIZE_OF_SYCL_RANGE = 3
3433

@@ -142,10 +141,15 @@ def _build_nullptr(self):
142141
143142
Returns: An LLVM Value storing a null pointer
144143
"""
145-
zero = cgutils.alloca_once(self.builder, utils.LLVMTypes.int64_t)
144+
zero = cgutils.alloca_once(
145+
self.builder, cgutils_extra.LLVMTypes.int64_t
146+
)
146147
self.builder.store(self.context.get_constant(types.int64, 0), zero)
147148
return self.builder.bitcast(
148-
zero, utils.get_llvm_type(context=self.context, type=types.voidptr)
149+
zero,
150+
cgutils_extra.get_llvm_type(
151+
context=self.context, type=types.voidptr
152+
),
149153
)
150154

151155
# TODO: remove, not part of the builder
@@ -159,7 +163,9 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
159163
# Allocate a stack var to store the queue created from the filter string
160164
sycl_queue_val = cgutils.alloca_once(
161165
self.builder,
162-
utils.get_llvm_type(context=self.context, type=types.voidptr),
166+
cgutils_extra.get_llvm_type(
167+
context=self.context, type=types.voidptr
168+
),
163169
)
164170
# Insert a global constant to store the filter string
165171
device = self.context.insert_const_string(
@@ -256,8 +262,8 @@ def _create_sycl_range(self, idx_range):
256262
"""
257263
int64_range = [
258264
(
259-
self.builder.sext(rext, utils.LLVMTypes.int64_t)
260-
if rext.type != utils.LLVMTypes.int64_t
265+
self.builder.sext(rext, cgutils_extra.LLVMTypes.int64_t)
266+
if rext.type != cgutils_extra.LLVMTypes.int64_t
261267
else rext
262268
)
263269
for rext in idx_range
@@ -334,7 +340,7 @@ def set_kernel_from_spirv(
334340
)
335341
else:
336342
spv_compiler_options = self.builder.load(
337-
create_null_ptr(self.builder, self.context)
343+
cgutils_extra.create_null_ptr(self.builder, self.context)
338344
)
339345

340346
# build_or_get_kernel steals reference to context and device cause it

numba_dpex/core/utils/kernel_flattened_args_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from numba.core import cgutils, types
1717
from numba.core.cpu import CPUContext
1818

19-
from numba_dpex import utils
2019
from numba_dpex.core.types import USMNdArray
2120
from numba_dpex.core.types.kernel_api.local_accessor import (
2221
DpctlMDLocalAccessorType,
2322
LocalAccessorType,
2423
)
24+
from numba_dpex.core.utils import cgutils_extra as utils
2525
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
2626

2727

0 commit comments

Comments
 (0)