8
8
9
9
from functools import cached_property
10
10
11
+ from llvmlite import ir as llvmir
12
+ from numba .core import types
11
13
from numba .core .descriptors import TargetDescriptor
12
14
from numba .core .target_extension import GPU , target_registry
15
+ from numba .core .types .scalars import IntEnumClass
13
16
14
17
from numba_dpex .core .descriptor import DpexTargetOptions
15
18
from numba_dpex .core .targets .kernel_target import (
18
21
)
19
22
from numba_dpex .experimental .models import exp_dmm
20
23
24
+ from .flag_enum import FlagEnum
25
+ from .literal_intenum_type import IntEnumLiteral
26
+
21
27
22
28
# pylint: disable=R0903
23
29
class SyclDeviceExp (GPU ):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
39
45
are stable enough to be migrated to DpexKernelTypingContext.
40
46
"""
41
47
48
+ def resolve_value_type (self , val ):
49
+ """
50
+ Return the numba type of a Python value that is being used
51
+ as a runtime constant.
52
+ ValueError is raised for unsupported types.
53
+ """
54
+
55
+ ty = super ().resolve_value_type (val )
56
+
57
+ if isinstance (ty , IntEnumClass ) and issubclass (val , FlagEnum ):
58
+ ty = IntEnumLiteral (val )
59
+
60
+ return ty
61
+
62
+ def resolve_getattr (self , typ , attr ):
63
+ """
64
+ Resolve getting the attribute *attr* (a string) on the Numba type.
65
+ The attribute's type is returned, or None if resolution failed.
66
+ """
67
+ ty = None
68
+
69
+ if isinstance (typ , IntEnumLiteral ):
70
+ try :
71
+ attrval = getattr (typ .literal_value , attr ).value
72
+ ty = types .IntegerLiteral (attrval )
73
+ except ValueError :
74
+ pass
75
+ else :
76
+ ty = super ().resolve_getattr (typ , attr )
77
+ return ty
78
+
42
79
43
80
# pylint: disable=W0223
44
81
# FIXME: Remove the pylint disablement once we add an override for
@@ -56,6 +93,22 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
56
93
super ().__init__ (typingctx , target )
57
94
self .data_model_manager = exp_dmm
58
95
96
+ def get_getattr (self , typ , attr ):
97
+ """
98
+ Overrides the get_getattr function to provide an implementation for
99
+ getattr call on an IntegerEnumLiteral type.
100
+ """
101
+
102
+ if isinstance (typ , IntEnumLiteral ):
103
+ # pylint: disable=W0613
104
+ def enum_literal_getattr_imp (context , builder , typ , val , attr ):
105
+ enum_attr_value = getattr (typ .literal_value , attr ).value
106
+ return llvmir .Constant (llvmir .IntType (64 ), enum_attr_value )
107
+
108
+ return enum_literal_getattr_imp
109
+
110
+ return super ().get_getattr (typ , attr )
111
+
59
112
60
113
class DpexExpKernelTarget (TargetDescriptor ):
61
114
"""
0 commit comments