Skip to content

Commit a92e9b3

Browse files
author
Diptorup Deb
authored
Merge pull request #1227 from IntelPython/experimental/inteumliteral
Adds a new literal type to store IntEnum as Literal types.
2 parents f1be213 + b1ac8d6 commit a92e9b3

File tree

13 files changed

+332
-8
lines changed

13 files changed

+332
-8
lines changed

numba_dpex/core/exceptions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,15 @@ def __init__(self, extra_msg=None) -> None:
386386
if extra_msg:
387387
self.message += " due to " + extra_msg
388388
super().__init__(self.message)
389+
390+
391+
class IllegalIntEnumLiteralValueError(Exception):
392+
"""Exception raised when an IntEnumLiteral is attempted to be created from
393+
a non FlagEnum attribute.
394+
"""
395+
396+
def __init__(self) -> None:
397+
self.message = (
398+
"An IntEnumLiteral can only be initialized from a FlagEnum member"
399+
)
400+
super().__init__(self.message)

numba_dpex/experimental/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
from numba.core.imputils import Registry
1010

11-
from .decorators import kernel
11+
from .decorators import device_func, kernel
1212
from .kernel_dispatcher import KernelDispatcher
1313
from .launcher import call_kernel, call_kernel_async
14+
from .literal_intenum_type import IntEnumLiteral
1415
from .models import *
1516
from .types import KernelDispatcherType
1617

@@ -26,4 +27,11 @@ def dpex_dispatcher_const(context):
2627
return context.get_dummy_value()
2728

2829

29-
__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]
30+
__all__ = [
31+
"device_func",
32+
"kernel",
33+
"call_kernel",
34+
"call_kernel_async",
35+
"IntEnumLiteral",
36+
"KernelDispatcher",
37+
]

numba_dpex/experimental/flag_enum.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex
6+
intends to use as Integer literal types inside the compiler type inferring
7+
infrastructure.
8+
"""
9+
from enum import IntEnum
10+
11+
12+
class FlagEnum(IntEnum):
13+
"""Helper class to distinguish IntEnum types that numba-dpex should consider
14+
as Numba Literal types.
15+
"""
16+
17+
@classmethod
18+
def basetype(cls) -> int:
19+
"""Returns an dummy int object that helps numba-dpex infer the type of
20+
an instance of a FlagEnum class.
21+
22+
Returns:
23+
int: Dummy int value
24+
"""
25+
return int(0)

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,12 @@ def add_overload(self, cres):
262262
args = tuple(cres.signature.args)
263263
self.overloads[args] = cres
264264

265-
def get_overload_device_ir(self, sig):
265+
def get_overload_kcres(self, sig) -> _KernelCompileResult:
266266
"""
267-
Return the compiled device bitcode for the given signature.
267+
Return the compiled function for the given signature.
268268
"""
269269
args, _ = sigutils.normalize_signature(sig)
270-
return self.overloads[tuple(args)].kernel_device_ir_module
270+
return self.overloads[tuple(args)]
271271

272272
def compile(self, sig) -> any:
273273
disp = self._get_dispatcher_for_current_target()

numba_dpex/experimental/launcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@ def _submit_kernel(
303303
# codegen
304304
kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher
305305
kernel_dispatcher.compile(kernel_sig)
306-
kernel_module: _KernelModule = kernel_dispatcher.get_overload_device_ir(
306+
kernel_module: _KernelModule = kernel_dispatcher.get_overload_kcres(
307307
kernel_sig
308-
)
308+
).kernel_device_ir_module
309309
kernel_targetctx = kernel_dispatcher.targetctx
310310

311311
def codegen(cgctx, builder, sig, llargs):
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Definition of a new Literal type in numba-dpex that allows treating IntEnum
6+
members as integer literals inside a JIT compiled function.
7+
"""
8+
from enum import IntEnum
9+
10+
from numba.core.pythonapi import box
11+
from numba.core.typeconv import Conversion
12+
from numba.core.types import Integer, Literal
13+
from numba.core.typing.typeof import typeof
14+
15+
from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
16+
from numba_dpex.experimental.flag_enum import FlagEnum
17+
18+
19+
class IntEnumLiteral(Literal, Integer):
20+
"""A Literal type for IntEnum objects. The type contains the original Python
21+
value of the IntEnum class in it.
22+
"""
23+
24+
# pylint: disable=W0231
25+
def __init__(self, value):
26+
self._literal_init(value)
27+
self.name = f"Literal[IntEnum]({value})"
28+
if issubclass(value, FlagEnum):
29+
basetype = typeof(value.basetype())
30+
Integer.__init__(
31+
self,
32+
name=self.name,
33+
bitwidth=basetype.bitwidth,
34+
signed=basetype.signed,
35+
)
36+
else:
37+
raise IllegalIntEnumLiteralValueError
38+
39+
def can_convert_to(self, typingctx, other) -> bool:
40+
conv = typingctx.can_convert(self.literal_type, other)
41+
if conv is not None:
42+
return max(conv, Conversion.promote)
43+
return False
44+
45+
46+
Literal.ctor_map[IntEnum] = IntEnumLiteral
47+
48+
49+
@box(IntEnumLiteral)
50+
def box_literal_integer(typ, val, c):
51+
"""Defines how a Numba representation for an IntEnumLiteral object should
52+
be converted to a PyObject* object and returned back to Python.
53+
"""
54+
val = c.context.cast(c.builder, val, typ, typ.literal_type)
55+
return c.box(typ.literal_type, val)

numba_dpex/experimental/models.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,28 @@
66
numba_dpex.experimental module.
77
"""
88

9+
from llvmlite import ir as llvmir
910
from numba.core.datamodel import DataModelManager, models
11+
from numba.core.datamodel.models import PrimitiveModel
1012
from numba.core.extending import register_model
1113

1214
import numba_dpex.core.datamodel.models as dpex_core_models
1315

16+
from .literal_intenum_type import IntEnumLiteral
1417
from .types import KernelDispatcherType
1518

1619

20+
class IntEnumLiteralModel(PrimitiveModel):
21+
"""Representation of an object of LiteralIntEnum type using Numba's
22+
PrimitiveModel that can be represented natively in the target in all
23+
usage contexts.
24+
"""
25+
26+
def __init__(self, dmm, fe_type):
27+
be_type = llvmir.IntType(fe_type.bitwidth)
28+
super().__init__(dmm, fe_type, be_type)
29+
30+
1731
def _init_exp_data_model_manager() -> DataModelManager:
1832
"""Initializes a DpexExpKernelTarget-specific data model manager.
1933
@@ -28,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager:
2842
dmm = dpex_core_models.dpex_data_model_manager.copy()
2943

3044
# Register the types and data model in the DpexExpTargetContext
31-
# Add here...
45+
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
3246

3347
return dmm
3448

numba_dpex/experimental/target.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88

99
from functools import cached_property
1010

11+
from llvmlite import ir as llvmir
12+
from numba.core import types
1113
from numba.core.descriptors import TargetDescriptor
1214
from numba.core.target_extension import GPU, target_registry
15+
from numba.core.types.scalars import IntEnumClass
1316

1417
from numba_dpex.core.descriptor import DpexTargetOptions
1518
from numba_dpex.core.targets.kernel_target import (
@@ -18,6 +21,9 @@
1821
)
1922
from numba_dpex.experimental.models import exp_dmm
2023

24+
from .flag_enum import FlagEnum
25+
from .literal_intenum_type import IntEnumLiteral
26+
2127

2228
# pylint: disable=R0903
2329
class SyclDeviceExp(GPU):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
3945
are stable enough to be migrated to DpexKernelTypingContext.
4046
"""
4147

48+
def resolve_value_type(self, val):
49+
"""
50+
Return the numba type of a Python value that is being used
51+
as a runtime constant.
52+
ValueError is raised for unsupported types.
53+
"""
54+
55+
ty = super().resolve_value_type(val)
56+
57+
if isinstance(ty, IntEnumClass) and issubclass(val, FlagEnum):
58+
ty = IntEnumLiteral(val)
59+
60+
return ty
61+
62+
def resolve_getattr(self, typ, attr):
63+
"""
64+
Resolve getting the attribute *attr* (a string) on the Numba type.
65+
The attribute's type is returned, or None if resolution failed.
66+
"""
67+
ty = None
68+
69+
if isinstance(typ, IntEnumLiteral):
70+
try:
71+
attrval = getattr(typ.literal_value, attr).value
72+
ty = types.IntegerLiteral(attrval)
73+
except ValueError:
74+
pass
75+
else:
76+
ty = super().resolve_getattr(typ, attr)
77+
return ty
78+
4279

4380
# pylint: disable=W0223
4481
# FIXME: Remove the pylint disablement once we add an override for
@@ -52,10 +89,28 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
5289
they are stable enough to be migrated to DpexKernelTargetContext.
5390
"""
5491

92+
allow_dynamic_globals = True
93+
5594
def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
5695
super().__init__(typingctx, target)
5796
self.data_model_manager = exp_dmm
5897

98+
def get_getattr(self, typ, attr):
99+
"""
100+
Overrides the get_getattr function to provide an implementation for
101+
getattr call on an IntegerEnumLiteral type.
102+
"""
103+
104+
if isinstance(typ, IntEnumLiteral):
105+
# pylint: disable=W0613
106+
def enum_literal_getattr_imp(context, builder, typ, val, attr):
107+
enum_attr_value = getattr(typ.literal_value, attr).value
108+
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)
109+
110+
return enum_literal_getattr_imp
111+
112+
return super().get_getattr(typ, attr)
113+
59114

60115
class DpexExpKernelTarget(TargetDescriptor):
61116
"""

numba_dpex/tests/experimental/IntEnumLiteral/__init__.py

Whitespace-only changes.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
7+
import numba_dpex.experimental as exp_dpex
8+
from numba_dpex import Range
9+
from numba_dpex.experimental.flag_enum import FlagEnum
10+
11+
12+
class MockFlags(FlagEnum):
13+
FLAG1 = 100
14+
FLAG2 = 200
15+
16+
17+
@exp_dpex.kernel(
18+
release_gil=False,
19+
no_compile=True,
20+
no_cpython_wrapper=True,
21+
no_cfunc_wrapper=True,
22+
)
23+
def update_with_flag(a):
24+
a[0] = MockFlags.FLAG1
25+
a[1] = MockFlags.FLAG2
26+
27+
28+
def test_compilation_of_flag_enum():
29+
"""Tests if a FlagEnum subclass can be used inside a kernel function."""
30+
a = dpnp.ones(10, dtype=dpnp.int64)
31+
exp_dpex.call_kernel(update_with_flag, Range(10), a)
32+
33+
assert a[0] == MockFlags.FLAG1
34+
assert a[1] == MockFlags.FLAG2
35+
for idx in range(2, 9):
36+
assert a[idx] == 1

0 commit comments

Comments
 (0)