8
8
9
9
from functools import cached_property
10
10
11
+ from llvmlite import ir as llvmir
12
+ from numba .core import types
11
13
from numba .core .descriptors import TargetDescriptor
12
14
from numba .core .target_extension import GPU , target_registry
15
+ from numba .core .types .scalars import IntEnumClass
13
16
14
17
from numba_dpex .core .descriptor import DpexTargetOptions
15
18
from numba_dpex .core .targets .kernel_target import (
18
21
)
19
22
from numba_dpex .experimental .models import exp_dmm
20
23
24
+ from .flag_enum import FlagEnum
25
+ from .literal_intenum_type import IntEnumLiteral
26
+
21
27
22
28
# pylint: disable=R0903
23
29
class SyclDeviceExp (GPU ):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
39
45
are stable enough to be migrated to DpexKernelTypingContext.
40
46
"""
41
47
48
+ def resolve_value_type (self , val ):
49
+ """
50
+ Return the numba type of a Python value that is being used
51
+ as a runtime constant.
52
+ ValueError is raised for unsupported types.
53
+ """
54
+
55
+ ty = super ().resolve_value_type (val )
56
+
57
+ if isinstance (ty , IntEnumClass ) and issubclass (val , FlagEnum ):
58
+ ty = IntEnumLiteral (val )
59
+
60
+ return ty
61
+
62
+ def resolve_getattr (self , typ , attr ):
63
+ """
64
+ Resolve getting the attribute *attr* (a string) on the Numba type.
65
+ The attribute's type is returned, or None if resolution failed.
66
+ """
67
+ ty = None
68
+
69
+ if isinstance (typ , IntEnumLiteral ):
70
+ try :
71
+ attrval = getattr (typ .literal_value , attr ).value
72
+ ty = types .IntegerLiteral (attrval )
73
+ except ValueError :
74
+ pass
75
+ else :
76
+ ty = super ().resolve_getattr (typ , attr )
77
+ return ty
78
+
42
79
43
80
# pylint: disable=W0223
44
81
# FIXME: Remove the pylint disablement once we add an override for
@@ -52,10 +89,28 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
52
89
they are stable enough to be migrated to DpexKernelTargetContext.
53
90
"""
54
91
92
+ allow_dynamic_globals = True
93
+
55
94
def __init__ (self , typingctx , target = DPEX_KERNEL_EXP_TARGET_NAME ):
56
95
super ().__init__ (typingctx , target )
57
96
self .data_model_manager = exp_dmm
58
97
98
+ def get_getattr (self , typ , attr ):
99
+ """
100
+ Overrides the get_getattr function to provide an implementation for
101
+ getattr call on an IntegerEnumLiteral type.
102
+ """
103
+
104
+ if isinstance (typ , IntEnumLiteral ):
105
+ # pylint: disable=W0613
106
+ def enum_literal_getattr_imp (context , builder , typ , val , attr ):
107
+ enum_attr_value = getattr (typ .literal_value , attr ).value
108
+ return llvmir .Constant (llvmir .IntType (64 ), enum_attr_value )
109
+
110
+ return enum_literal_getattr_imp
111
+
112
+ return super ().get_getattr (typ , attr )
113
+
59
114
60
115
class DpexExpKernelTarget (TargetDescriptor ):
61
116
"""
0 commit comments