Skip to content

Commit 9bb4a47

Browse files
committed
Remove direct datamodel.default_manager use.
1 parent af226f3 commit 9bb4a47

File tree

10 files changed

+62
-57
lines changed

10 files changed

+62
-57
lines changed

numba_dpex/core/boxing/ranges.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from contextlib import ExitStack
66

77
from numba.core import cgutils, types
8-
from numba.core.datamodel import default_manager
98
from numba.extending import NativeValue, box, unbox
109

1110
from numba_dpex.core.types import NdRangeType, RangeType
@@ -78,7 +77,9 @@ def unbox_ndrange(typ, obj, c):
7877
].value
7978
local_range_struct = ndrange_attr_native_value_map["local_range"].value
8079

81-
range_datamodel = default_manager.lookup(RangeType(typ.ndim))
80+
range_datamodel = c.context.data_model_manager.lookup(
81+
RangeType(typ.ndim)
82+
)
8283
ndrange_struct.ndim = c.builder.extract_value(
8384
global_range_struct,
8485
range_datamodel.get_field_position("ndim"),

numba_dpex/core/boxing/usm_ndarray.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from contextlib import ExitStack
66

77
from numba.core import cgutils, types
8-
from numba.core.datamodel import default_manager
98
from numba.core.errors import NumbaNotImplementedError
109
from numba.extending import NativeValue, box, unbox
1110
from numba.np import numpy_support

numba_dpex/core/datamodel/models.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from llvmlite import ir as llvmir
66
from numba.core import datamodel, types
77
from numba.core.datamodel.models import OpaqueModel, PrimitiveModel, StructModel
8-
from numba.core.extending import register_model
98

109
from numba_dpex.core.exceptions import UnreachableError
1110
from numba_dpex.core.types.kernel_api.atomic_ref import AtomicRefType
@@ -316,7 +315,7 @@ def __init__(self, dmm, fe_type):
316315
super().__init__(dmm, fe_type, members)
317316

318317

319-
def _init_data_model_manager() -> datamodel.DataModelManager:
318+
def _init_kernel_data_model_manager() -> datamodel.DataModelManager:
320319
"""Initializes a data model manager used by the SPRIVTarget.
321320
322321
SPIRV kernel functions for certain types of devices require an explicit
@@ -370,43 +369,50 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
370369
return dmm
371370

372371

373-
dpex_data_model_manager = _init_data_model_manager()
372+
def _init_dpjit_data_model_manager() -> datamodel.DataModelManager:
373+
# TODO: copy manager
374+
dmm = datamodel.default_manager
374375

376+
# Register the USMNdArray type to USMArrayHostModel in numba's default data
377+
# model manager
378+
dmm.register(USMNdArray, USMArrayHostModel)
375379

376-
# Register the USMNdArray type to USMArrayDeviceModel in numba's default data
377-
# model manager
378-
register_model(USMNdArray)(USMArrayHostModel)
380+
# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
381+
# model manager
382+
dmm.register(DpnpNdArray, USMArrayHostModel)
379383

380-
# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
381-
# model manager
382-
register_model(DpnpNdArray)(USMArrayHostModel)
384+
# Register the DpctlSyclQueue type
385+
dmm.register(DpctlSyclQueue, SyclQueueModel)
383386

384-
# Register the DpctlSyclQueue type
385-
register_model(DpctlSyclQueue)(SyclQueueModel)
387+
# Register the DpctlSyclEvent type
388+
dmm.register(DpctlSyclEvent, SyclEventModel)
386389

387-
# Register the DpctlSyclEvent type
388-
register_model(DpctlSyclEvent)(SyclEventModel)
390+
# Register the RangeType type
391+
dmm.register(RangeType, RangeModel)
389392

390-
# Register the RangeType type
391-
register_model(RangeType)(RangeModel)
393+
# Register the NdRangeType type
394+
dmm.register(NdRangeType, NdRangeModel)
392395

393-
# Register the NdRangeType type
394-
register_model(NdRangeType)(NdRangeModel)
396+
# Register the GroupType type
397+
dmm.register(GroupType, EmptyStructModel)
395398

396-
# Register the GroupType type
397-
register_model(GroupType)(EmptyStructModel)
399+
# Register the ItemType type
400+
dmm.register(ItemType, EmptyStructModel)
398401

399-
# Register the ItemType type
400-
register_model(ItemType)(EmptyStructModel)
402+
# Register the NdItemType type
403+
dmm.register(NdItemType, EmptyStructModel)
404+
405+
# Register the MDLocalAccessorType type
406+
dmm.register(DpctlMDLocalAccessorType, DpctlMDLocalAccessorModel)
401407

402-
# Register the NdItemType type
403-
register_model(NdItemType)(EmptyStructModel)
408+
# Register the LocalAccessorType type
409+
dmm.register(LocalAccessorType, LocalAccessorModel)
404410

405-
# Register the MDLocalAccessorType type
406-
register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel)
411+
# Register the KernelDispatcherType type
412+
dmm.register(KernelDispatcherType, OpaqueModel)
413+
414+
return dmm
407415

408-
# Register the LocalAccessorType type
409-
register_model(LocalAccessorType)(LocalAccessorModel)
410416

411-
# Register the KernelDispatcherType type
412-
register_model(KernelDispatcherType)(OpaqueModel)
417+
dpex_data_model_manager = _init_kernel_data_model_manager()
418+
dpjit_data_model_manager = _init_dpjit_data_model_manager()

numba_dpex/core/kernel_interface/ranges_overloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from llvmlite import ir as llvmir
66
from numba.core import cgutils, errors, types
7-
from numba.core.datamodel import default_manager
87
from numba.extending import intrinsic, overload
98

109
from numba_dpex.kernel_api import NdRange, Range
@@ -60,11 +59,12 @@ def _intrin_ndrange_alloc(
6059
ty_local_range,
6160
ty_ndrange,
6261
)
63-
range_datamodel = default_manager.lookup(ty_global_range)
6462

6563
def codegen(context, builder, sig, args):
6664
typ = sig.return_type
6765

66+
range_datamodel = context.data_model_manager.lookup(ty_global_range)
67+
6868
global_range, local_range, _ = args
6969
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
7070
ndrange_struct.ndim = llvmir.Constant(

numba_dpex/core/targets/dpjit_target.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numba.core.imputils import Registry
1414
from numba.core.target_extension import CPU, target_registry
1515

16+
from numba_dpex.core.datamodel.models import _init_dpjit_data_model_manager
1617
from numba_dpex.dpnp_iface import dpnp_ufunc_db
1718

1819

@@ -49,6 +50,8 @@ def init(self):
4950
self.lower_extensions = {}
5051
super().init()
5152

53+
self.data_model_manager = _init_dpjit_data_model_manager()
54+
5255
# TODO: initialize nrt once switched to nrt from drt. Most likely we
5356
# call it somewhere. Double check.
5457
# https://github.com/IntelPython/numba-dpex/issues/1175

numba_dpex/dpctl_iface/_intrinsic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import dpctl
66
from llvmlite.ir import IRBuilder
77
from numba import types
8-
from numba.core.datamodel import default_manager
98
from numba.extending import intrinsic, overload, overload_method
109

1110
import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
@@ -45,7 +44,7 @@ def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):
4544

4645
# defines the custom code generation
4746
def codegen(context, builder, signature, args):
48-
sycl_event_dm = default_manager.lookup(ty_event)
47+
sycl_event_dm = context.data_model_manager.lookup(ty_event)
4948
event_ref = builder.extract_value(
5049
args[0],
5150
sycl_event_dm.get_field_position("event_ref"),

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from numba.core.types.scalars import IntEnumClass
2121
from numba.core.typing import cmathdecl, enumdecl
2222

23-
from numba_dpex.core.datamodel.models import _init_data_model_manager
23+
from numba_dpex.core.datamodel.models import _init_kernel_data_model_manager
2424
from numba_dpex.core.types import IntEnumLiteral
2525
from numba_dpex.core.typing import dpnpdecl
2626
from numba_dpex.kernel_api.flag_enum import FlagEnum
@@ -154,7 +154,7 @@ def init(self):
154154
)
155155

156156
# Override data model manager to SPIR model
157-
self.data_model_manager = _init_data_model_manager()
157+
self.data_model_manager = _init_kernel_data_model_manager()
158158
self.extra_compile_options = {}
159159

160160
_lazy_init_dpnp_db()

numba_dpex/tests/core/types/DpctlSyclEvent/test_models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import dpctl
6-
from numba import types
7-
from numba.core.datamodel import default_manager, models
5+
from numba.core.datamodel import models
86

97
from numba_dpex.core.datamodel.models import (
108
SyclEventModel,
11-
dpex_data_model_manager,
9+
dpjit_data_model_manager,
1210
)
1311
from numba_dpex.core.types.dpctl_types import DpctlSyclEvent
1412

@@ -18,7 +16,7 @@ def test_model_for_DpctlSyclEvent():
1816
default data model manager.
1917
"""
2018
sycl_event = DpctlSyclEvent()
21-
default_model = default_manager.lookup(sycl_event)
19+
default_model = dpjit_data_model_manager.lookup(sycl_event)
2220
assert isinstance(default_model, SyclEventModel)
2321

2422

numba_dpex/tests/core/types/range_types/test_data_model.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import pytest
6-
from numba.core.datamodel import default_manager
7-
from numba.core.registry import cpu_target
86

97
from numba_dpex.core.datamodel.models import (
108
NdRangeModel,
119
RangeModel,
1210
dpex_data_model_manager,
11+
dpjit_data_model_manager,
1312
)
14-
from numba_dpex.core.descriptor import dpex_kernel_target
13+
from numba_dpex.core.descriptor import dpex_kernel_target, dpex_target
1514
from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType
1615

1716
rfields = ["ndim", "dim0", "dim1", "dim2"]
@@ -30,8 +29,8 @@ def test_datamodel_registration():
3029
dpex_data_model_manager.lookup(range_ty)
3130
dpex_data_model_manager.lookup(ndrange_ty)
3231

33-
default_range_model = default_manager.lookup(range_ty)
34-
default_ndrange_model = default_manager.lookup(ndrange_ty)
32+
default_range_model = dpjit_data_model_manager.lookup(range_ty)
33+
default_ndrange_model = dpjit_data_model_manager.lookup(ndrange_ty)
3534

3635
assert isinstance(default_range_model, RangeModel)
3736
assert isinstance(default_ndrange_model, NdRangeModel)
@@ -43,7 +42,7 @@ def test_range_model_fields(field):
4342
RangeType
4443
"""
4544
range_ty = RangeType(ndim=1)
46-
dm = default_manager.lookup(range_ty)
45+
dm = dpjit_data_model_manager.lookup(range_ty)
4746
try:
4847
dm.get_field_position(field)
4948
except:
@@ -56,7 +55,7 @@ def test_ndrange_model_fields(field):
5655
NdRangeType
5756
"""
5857
ndrange_ty = NdRangeType(ndim=1)
59-
dm = default_manager.lookup(ndrange_ty)
58+
dm = dpjit_data_model_manager.lookup(ndrange_ty)
6059
try:
6160
dm.get_field_position(field)
6261
except:
@@ -69,15 +68,14 @@ def test_flattened_member_count(range_type):
6968
flattened args generated by the CpuTarget's ArgPacker.
7069
"""
7170

72-
cputargetctx = cpu_target.target_context
73-
kerneltargetctx = dpex_kernel_target.target_context
74-
dpex_dmm = kerneltargetctx.data_model_manager
71+
dpjit_target_ctx = dpex_target.target_context
72+
dpjit_dmm = dpjit_target_ctx.data_model_manager
7573

7674
for ndim in range(1, 3):
7775
dty = range_type(ndim)
7876
argty_tuple = tuple([dty])
79-
datamodel = dpex_dmm.lookup(dty)
77+
datamodel = dpjit_dmm.lookup(dty)
8078
num_flattened_args = datamodel.flattened_field_count
81-
ap = cputargetctx.get_arg_packer(argty_tuple)
79+
ap = dpjit_target_ctx.get_arg_packer(argty_tuple)
8280

8381
assert num_flattened_args == len(ap._be_args)

numba_dpex/tests/core/types/test_array_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
import pytest
66
from numba import types
7-
from numba.core.datamodel import default_manager, models
7+
from numba.core.datamodel import models
88
from numba.core.registry import cpu_target
99

1010
from numba_dpex.core.datamodel.models import (
1111
USMArrayDeviceModel,
1212
USMArrayHostModel,
1313
dpex_data_model_manager,
14+
dpjit_data_model_manager,
1415
)
1516
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray, USMNdArray
1617

@@ -32,7 +33,7 @@ def test_model_for_array(nd_array):
3233
"""
3334
device_model = dpex_data_model_manager.lookup(nd_array)
3435
assert isinstance(device_model, USMArrayDeviceModel)
35-
host_model = default_manager.lookup(nd_array)
36+
host_model = dpjit_data_model_manager.lookup(nd_array)
3637
assert isinstance(host_model, USMArrayHostModel)
3738

3839

0 commit comments

Comments
 (0)