Skip to content

Commit d0bedd2

Browse files
author
Diptorup Deb
committed
Adds a new literal type to store IntEnum as Literal types.
- Adds a new IntEnumLiteral type with corresponding data model into the DpexExpKernelTargetContext. The type is used to pass in or define an IntEnum flag as an Integer literal inside a kernel function.
1 parent f1be213 commit d0bedd2

File tree

6 files changed

+168
-2
lines changed

6 files changed

+168
-2
lines changed

numba_dpex/core/exceptions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,15 @@ def __init__(self, extra_msg=None) -> None:
386386
if extra_msg:
387387
self.message += " due to " + extra_msg
388388
super().__init__(self.message)
389+
390+
391+
class IllegalIntEnumLiteralValueError(Exception):
392+
"""Exception raised when an IntEnumLiteral is attempted to be created from
393+
a non FlagEnum attribute.
394+
"""
395+
396+
def __init__(self) -> None:
397+
self.message = (
398+
"An IntEnumLiteral can only be initialized from a FlagEnum member"
399+
)
400+
super().__init__(self.message)

numba_dpex/experimental/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .decorators import kernel
1212
from .kernel_dispatcher import KernelDispatcher
1313
from .launcher import call_kernel, call_kernel_async
14+
from .literal_intenum_type import IntEnumLiteral
1415
from .models import *
1516
from .types import KernelDispatcherType
1617

@@ -26,4 +27,10 @@ def dpex_dispatcher_const(context):
2627
return context.get_dummy_value()
2728

2829

29-
__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]
30+
__all__ = [
31+
"kernel",
32+
"call_kernel",
33+
"call_kernel_async",
34+
"IntEnumLiteral",
35+
"KernelDispatcher",
36+
]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex
6+
intends to use as Integer literal types inside the compiler type inferring
7+
infrastructure.
8+
"""
9+
from enum import IntEnum
10+
11+
12+
class FlagEnum(IntEnum):
13+
"""Helper class to distinguish IntEnum types that numba-dpex should consider
14+
as Numba Literal types.
15+
"""
16+
17+
@classmethod
18+
def basetype(cls) -> int:
19+
"""Returns an dummy int object that helps numba-dpex infer the type of
20+
an instance of a FlagEnum class.
21+
22+
Returns:
23+
int: Dummy int value
24+
"""
25+
return int(0)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Definition of a new Literal type in numba-dpex that allows treating IntEnum
6+
members as integer literals inside a JIT compiled function.
7+
"""
8+
from enum import IntEnum
9+
10+
from numba.core.pythonapi import box
11+
from numba.core.typeconv import Conversion
12+
from numba.core.types import Integer, Literal
13+
from numba.core.typing.typeof import typeof
14+
15+
from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
16+
from numba_dpex.experimental.flag_enum import FlagEnum
17+
18+
19+
class IntEnumLiteral(Literal, Integer):
20+
"""A Literal type for IntEnum objects. The type contains the original Python
21+
value of the IntEnum class in it.
22+
"""
23+
24+
# pylint: disable=W0231
25+
def __init__(self, value):
26+
self._literal_init(value)
27+
self.name = f"Literal[IntEnum]({value})"
28+
if issubclass(value, FlagEnum):
29+
basetype = typeof(value.basetype())
30+
Integer.__init__(
31+
self,
32+
name=self.name,
33+
bitwidth=basetype.bitwidth,
34+
signed=basetype.signed,
35+
)
36+
else:
37+
raise IllegalIntEnumLiteralValueError
38+
39+
def can_convert_to(self, typingctx, other) -> bool:
40+
conv = typingctx.can_convert(self.literal_type, other)
41+
if conv is not None:
42+
return max(conv, Conversion.promote)
43+
return False
44+
45+
46+
Literal.ctor_map[IntEnum] = IntEnumLiteral
47+
48+
49+
@box(IntEnumLiteral)
50+
def box_literal_integer(typ, val, c):
51+
"""Defines how a Numba representation for an IntEnumLiteral object should
52+
be converted to a PyObject* object and returned back to Python.
53+
"""
54+
val = c.context.cast(c.builder, val, typ, typ.literal_type)
55+
return c.box(typ.literal_type, val)

numba_dpex/experimental/models.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,28 @@
66
numba_dpex.experimental module.
77
"""
88

9+
from llvmlite import ir as llvmir
910
from numba.core.datamodel import DataModelManager, models
11+
from numba.core.datamodel.models import PrimitiveModel
1012
from numba.core.extending import register_model
1113

1214
import numba_dpex.core.datamodel.models as dpex_core_models
1315

16+
from .literal_intenum_type import IntEnumLiteral
1417
from .types import KernelDispatcherType
1518

1619

20+
class LiteralIntEnumModel(PrimitiveModel):
21+
"""Representation of an object of LiteralIntEnum type using Numba's
22+
PrimitiveModel that can be represented natively in the target in all
23+
usage contexts.
24+
"""
25+
26+
def __init__(self, dmm, fe_type):
27+
be_type = llvmir.IntType(fe_type.bitwidth)
28+
super().__init__(dmm, fe_type, be_type)
29+
30+
1731
def _init_exp_data_model_manager() -> DataModelManager:
1832
"""Initializes a DpexExpKernelTarget-specific data model manager.
1933
@@ -28,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager:
2842
dmm = dpex_core_models.dpex_data_model_manager.copy()
2943

3044
# Register the types and data model in the DpexExpTargetContext
31-
# Add here...
45+
dmm.register(IntEnumLiteral, LiteralIntEnumModel)
3246

3347
return dmm
3448

numba_dpex/experimental/target.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88

99
from functools import cached_property
1010

11+
from llvmlite import ir as llvmir
12+
from numba.core import types
1113
from numba.core.descriptors import TargetDescriptor
1214
from numba.core.target_extension import GPU, target_registry
15+
from numba.core.types.scalars import IntEnumClass
1316

1417
from numba_dpex.core.descriptor import DpexTargetOptions
1518
from numba_dpex.core.targets.kernel_target import (
@@ -18,6 +21,9 @@
1821
)
1922
from numba_dpex.experimental.models import exp_dmm
2023

24+
from .flag_enum import FlagEnum
25+
from .literal_intenum_type import IntEnumLiteral
26+
2127

2228
# pylint: disable=R0903
2329
class SyclDeviceExp(GPU):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
3945
are stable enough to be migrated to DpexKernelTypingContext.
4046
"""
4147

48+
def resolve_value_type(self, val):
49+
"""
50+
Return the numba type of a Python value that is being used
51+
as a runtime constant.
52+
ValueError is raised for unsupported types.
53+
"""
54+
55+
ty = super().resolve_value_type(val)
56+
57+
if isinstance(ty, IntEnumClass) and issubclass(val, FlagEnum):
58+
ty = IntEnumLiteral(val)
59+
60+
return ty
61+
62+
def resolve_getattr(self, typ, attr):
63+
"""
64+
Resolve getting the attribute *attr* (a string) on the Numba type.
65+
The attribute's type is returned, or None if resolution failed.
66+
"""
67+
ty = None
68+
69+
if isinstance(typ, IntEnumLiteral):
70+
try:
71+
attrval = getattr(typ.literal_value, attr).value
72+
ty = types.IntegerLiteral(attrval)
73+
except ValueError:
74+
pass
75+
else:
76+
ty = super().resolve_getattr(typ, attr)
77+
return ty
78+
4279

4380
# pylint: disable=W0223
4481
# FIXME: Remove the pylint disablement once we add an override for
@@ -56,6 +93,22 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
5693
super().__init__(typingctx, target)
5794
self.data_model_manager = exp_dmm
5895

96+
def get_getattr(self, typ, attr):
97+
"""
98+
Overrides the get_getattr function to provide an implementation for
99+
getattr call on an IntegerEnumLiteral type.
100+
"""
101+
102+
if isinstance(typ, IntEnumLiteral):
103+
# pylint: disable=W0613
104+
def enum_literal_getattr_imp(context, builder, typ, val, attr):
105+
enum_attr_value = getattr(typ.literal_value, attr).value
106+
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)
107+
108+
return enum_literal_getattr_imp
109+
110+
return super().get_getattr(typ, attr)
111+
59112

60113
class DpexExpKernelTarget(TargetDescriptor):
61114
"""

0 commit comments

Comments
 (0)