Skip to content

Commit cf4b631

Browse files
author
Diptorup Deb
authored
Merge pull request #1317 from IntelPython/refactor/migrate_int_enum_literal
Migrate IntEnumLiteral into core.
2 parents ccb7606 + 269e08e commit cf4b631

File tree

11 files changed

+75
-83
lines changed

11 files changed

+75
-83
lines changed

numba_dpex/_kernel_api_impl/spirv/target.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
from numba.core.registry import cpu_target
1616
from numba.core.target_extension import GPU, target_registry
1717
from numba.core.types import Array as NpArrayType
18+
from numba.core.types.scalars import IntEnumClass
1819

1920
from numba_dpex.core.datamodel.models import _init_data_model_manager
2021
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
2122
from numba_dpex.core.typeconv import to_usm_ndarray
22-
from numba_dpex.core.types import USMNdArray
23+
from numba_dpex.core.types import IntEnumLiteral, USMNdArray
2324
from numba_dpex.core.utils import get_info_from_suai
25+
from numba_dpex.kernel_api.flag_enum import FlagEnum
2426
from numba_dpex.utils import address_space, calling_conv
2527

2628
from . import codegen
@@ -64,6 +66,37 @@ class SPIRVTypingContext(typing.BaseContext):
6466
6567
"""
6668

69+
def resolve_value_type(self, val):
70+
"""
71+
Return the numba type of a Python value that is being used
72+
as a runtime constant.
73+
ValueError is raised for unsupported types.
74+
"""
75+
76+
typ = super().resolve_value_type(val)
77+
78+
if isinstance(typ, IntEnumClass) and issubclass(val, FlagEnum):
79+
typ = IntEnumLiteral(val)
80+
81+
return typ
82+
83+
def resolve_getattr(self, typ, attr):
84+
"""
85+
Resolve getting the attribute *attr* (a string) on the Numba type.
86+
The attribute's type is returned, or None if resolution failed.
87+
"""
88+
retty = None
89+
90+
if isinstance(typ, IntEnumLiteral):
91+
try:
92+
attrval = getattr(typ.literal_value, attr).value
93+
retty = types.IntegerLiteral(attrval)
94+
except ValueError:
95+
pass
96+
else:
97+
retty = super().resolve_getattr(typ, attr)
98+
return retty
99+
67100
def resolve_argument_type(self, val):
68101
"""Return the Numba type of a Python value used as a function argument.
69102
@@ -269,6 +302,22 @@ def init(self):
269302

270303
self.ufunc_db = _dpnp_ufunc_db
271304

305+
def get_getattr(self, typ, attr):
306+
"""
307+
Overrides the get_getattr function to provide an implementation for
308+
getattr call on an IntegerEnumLiteral type.
309+
"""
310+
311+
if isinstance(typ, IntEnumLiteral):
312+
# pylint: disable=W0613
313+
def enum_literal_getattr_imp(context, builder, typ, val, attr):
314+
enum_attr_value = getattr(typ.literal_value, attr).value
315+
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)
316+
317+
return enum_literal_getattr_imp
318+
319+
return super().get_getattr(typ, attr)
320+
272321
def create_module(self, name):
273322
return self._internal_codegen._create_empty_module(name)
274323

numba_dpex/core/datamodel/models.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from llvmlite import ir as llvmir
56
from numba.core import datamodel, types
67
from numba.core.datamodel.models import PrimitiveModel, StructModel
78
from numba.core.extending import register_model
@@ -14,6 +15,7 @@
1415
DpctlSyclEvent,
1516
DpctlSyclQueue,
1617
DpnpNdArray,
18+
IntEnumLiteral,
1719
NdRangeType,
1820
RangeType,
1921
USMNdArray,
@@ -55,6 +57,17 @@ def __init__(self, dmm, fe_type):
5557
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)
5658

5759

60+
class IntEnumLiteralModel(PrimitiveModel):
61+
"""Representation of an object of LiteralIntEnum type using Numba's
62+
PrimitiveModel that can be represented natively in the target in all
63+
usage contexts.
64+
"""
65+
66+
def __init__(self, dmm, fe_type):
67+
be_type = llvmir.IntType(fe_type.bitwidth)
68+
super().__init__(dmm, fe_type, be_type)
69+
70+
5871
class USMArrayDeviceModel(StructModel):
5972
"""A data model to represent a usm array type in the LLVM IR generated for a
6073
device-only kernel function.
@@ -237,7 +250,7 @@ def flattened_field_count(self):
237250

238251

239252
def _init_data_model_manager() -> datamodel.DataModelManager:
240-
"""Initializes a DpexKernelTarget-specific data model manager.
253+
"""Initializes a data model manager used by the SPRIVTarget.
241254
242255
SPIRV kernel functions for certain types of devices require an explicit
243256
address space qualifier for pointers. For OpenCL HD Graphics
@@ -252,8 +265,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
252265
a dpnp.ndarray object can be passed to any other regular function.
253266
254267
Returns:
255-
DataModelManager: A numba-dpex DpexKernelTarget-specific data model
256-
manager
268+
DataModelManager: A numba-dpex SPIRVTarget-specific data model manager
257269
"""
258270
dmm = datamodel.default_manager.copy()
259271
dmm.register(types.CPointer, GenericPointerModel)
@@ -271,6 +283,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
271283
# model manager. The dpex_data_model_manager is used by the DpexKernelTarget
272284
dmm.register(DpctlSyclQueue, SyclQueueModel)
273285

286+
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
287+
274288
return dmm
275289

276290

numba_dpex/core/types/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .array_type import Array
66
from .dpctl_types import DpctlSyclEvent, DpctlSyclQueue
77
from .dpnp_ndarray_type import DpnpNdArray
8+
from .kernel_api.literal_intenum import IntEnumLiteral
89
from .kernel_api.ranges import NdRangeType, RangeType
910
from .numba_types_short_names import (
1011
b1,
@@ -36,8 +37,9 @@
3637
"DpctlSyclQueue",
3738
"DpctlSyclEvent",
3839
"DpnpNdArray",
39-
"RangeType",
40+
"IntEnumLiteral",
4041
"NdRangeType",
42+
"RangeType",
4143
"USMNdArray",
4244
"none",
4345
"boolean",
@@ -57,6 +59,6 @@
5759
"f8",
5860
"float_",
5961
"double",
60-
"void",
6162
"usm_ndarray",
63+
"void",
6264
]

numba_dpex/experimental/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121
from .decorators import device_func, kernel
2222
from .launcher import call_kernel, call_kernel_async
23-
from .literal_intenum_type import IntEnumLiteral
2423
from .models import *
2524
from .types import KernelDispatcherType
2625

@@ -41,6 +40,5 @@ def dpex_dispatcher_const(context):
4140
"kernel",
4241
"call_kernel",
4342
"call_kernel_async",
44-
"IntEnumLiteral",
4543
"SPIRVKernelDispatcher",
4644
]

numba_dpex/experimental/models.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
numba_dpex.experimental module.
77
"""
88

9-
from llvmlite import ir as llvmir
109
from numba.core import types
1110
from numba.core.datamodel import DataModelManager, models
12-
from numba.core.datamodel.models import PrimitiveModel, StructModel
11+
from numba.core.datamodel.models import StructModel
1312
from numba.core.extending import register_model
1413

1514
import numba_dpex.core.datamodel.models as dpex_core_models
@@ -20,7 +19,6 @@
2019
)
2120

2221
from .dpcpp_types import AtomicRefType
23-
from .literal_intenum_type import IntEnumLiteral
2422
from .types import KernelDispatcherType
2523

2624

@@ -37,17 +35,6 @@ def __init__(self, dmm, fe_type):
3735
super().__init__(dmm, fe_type, members)
3836

3937

40-
class IntEnumLiteralModel(PrimitiveModel):
41-
"""Representation of an object of LiteralIntEnum type using Numba's
42-
PrimitiveModel that can be represented natively in the target in all
43-
usage contexts.
44-
"""
45-
46-
def __init__(self, dmm, fe_type):
47-
be_type = llvmir.IntType(fe_type.bitwidth)
48-
super().__init__(dmm, fe_type, be_type)
49-
50-
5138
class EmptyStructModel(StructModel):
5239
"""Data model that does not take space. Intended to be used with types that
5340
are presented only at typing stage and not represented physically."""
@@ -71,7 +58,6 @@ def _init_exp_data_model_manager() -> DataModelManager:
7158
dmm = dpex_core_models.dpex_data_model_manager.copy()
7259

7360
# Register the types and data model in the DpexExpTargetContext
74-
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
7561
dmm.register(AtomicRefType, AtomicRefModel)
7662

7763
# Register the GroupType type

numba_dpex/experimental/target.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,15 @@
88

99
from functools import cached_property
1010

11-
from llvmlite import ir as llvmir
12-
from numba.core import types
1311
from numba.core.descriptors import TargetDescriptor
1412
from numba.core.target_extension import GPU, target_registry
15-
from numba.core.types.scalars import IntEnumClass
1613

1714
from numba_dpex._kernel_api_impl.spirv.target import (
1815
SPIRVTargetContext,
1916
SPIRVTypingContext,
2017
)
2118
from numba_dpex.core.descriptor import DpexTargetOptions
2219
from numba_dpex.experimental.models import exp_dmm
23-
from numba_dpex.kernel_api.flag_enum import FlagEnum
24-
25-
from .literal_intenum_type import IntEnumLiteral
2620

2721

2822
# pylint: disable=R0903
@@ -45,37 +39,6 @@ class DpexExpKernelTypingContext(SPIRVTypingContext):
4539
are stable enough to be migrated to DpexKernelTypingContext.
4640
"""
4741

48-
def resolve_value_type(self, val):
49-
"""
50-
Return the numba type of a Python value that is being used
51-
as a runtime constant.
52-
ValueError is raised for unsupported types.
53-
"""
54-
55-
typ = super().resolve_value_type(val)
56-
57-
if isinstance(typ, IntEnumClass) and issubclass(val, FlagEnum):
58-
typ = IntEnumLiteral(val)
59-
60-
return typ
61-
62-
def resolve_getattr(self, typ, attr):
63-
"""
64-
Resolve getting the attribute *attr* (a string) on the Numba type.
65-
The attribute's type is returned, or None if resolution failed.
66-
"""
67-
retty = None
68-
69-
if isinstance(typ, IntEnumLiteral):
70-
try:
71-
attrval = getattr(typ.literal_value, attr).value
72-
retty = types.IntegerLiteral(attrval)
73-
except ValueError:
74-
pass
75-
else:
76-
retty = super().resolve_getattr(typ, attr)
77-
return retty
78-
7942

8043
# pylint: disable=W0223
8144
# FIXME: Remove the pylint disablement once we add an override for
@@ -95,22 +58,6 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
9558
super().__init__(typingctx, target)
9659
self.data_model_manager = exp_dmm
9760

98-
def get_getattr(self, typ, attr):
99-
"""
100-
Overrides the get_getattr function to provide an implementation for
101-
getattr call on an IntegerEnumLiteral type.
102-
"""
103-
104-
if isinstance(typ, IntEnumLiteral):
105-
# pylint: disable=W0613
106-
def enum_literal_getattr_imp(context, builder, typ, val, attr):
107-
enum_attr_value = getattr(typ.literal_value, attr).value
108-
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)
109-
110-
return enum_literal_getattr_imp
111-
112-
return super().get_getattr(typ, attr)
113-
11461

11562
class DpexExpKernelTarget(TargetDescriptor):
11663
"""

numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py renamed to numba_dpex/tests/core/types/IntEnumLiteral/test_type_creation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
10-
from numba_dpex.experimental import IntEnumLiteral
10+
from numba_dpex.core.types import IntEnumLiteral
1111
from numba_dpex.kernel_api.flag_enum import FlagEnum
1212

1313

0 commit comments

Comments
 (0)