|
6 | 6 | Implements the SPIR-V overloads for the kernel_api.items class methods.
|
7 | 7 | """
|
8 | 8 |
|
| 9 | +import llvmlite.ir as llvmir |
9 | 10 | from numba.core import cgutils, types
|
10 | 11 | from numba.core.errors import TypingError
|
11 | 12 | from numba.extending import intrinsic, overload_method
|
12 | 13 |
|
| 14 | +from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext |
13 | 15 | from numba_dpex.experimental.core.types.kernel_api.items import (
|
14 | 16 | GroupType,
|
15 | 17 | ItemType,
|
16 | 18 | NdItemType,
|
17 | 19 | )
|
18 |
| -from numba_dpex.ocl._declare_function import _declare_function |
19 | 20 |
|
20 | 21 | from ..target import DPEX_KERNEL_EXP_TARGET_NAME
|
21 | 22 |
|
22 | 23 |
|
23 |
| -@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
24 |
| -def _intrinsic_get_global_id( |
25 |
| - ty_context, ty_dim # pylint: disable=unused-argument |
| 24 | +def spirv_name(name: str): |
| 25 | + """Converts name to spirv name by adding __spirv_ prefix.""" |
| 26 | + return "__spirv_" + name |
| 27 | + |
| 28 | + |
| 29 | +def declare_spirv_const( |
| 30 | + builder: llvmir.IRBuilder, |
| 31 | + name: str, |
| 32 | +): |
| 33 | + """Declares global external spirv constant""" |
| 34 | + data = cgutils.add_global_variable( |
| 35 | + builder.module, |
| 36 | + llvmir.VectorType(llvmir.IntType(64), 3), |
| 37 | + spirv_name(name), |
| 38 | + addrspace=1, |
| 39 | + ) |
| 40 | + data.linkage = "external" |
| 41 | + data.global_constant = True |
| 42 | + data.align = 32 |
| 43 | + data.storage_class = "dso_local local_unnamed_addr" |
| 44 | + return data |
| 45 | + |
| 46 | + |
| 47 | +def _intrinsic_spirv_global_index_const( |
| 48 | + ty_context, # pylint: disable=unused-argument |
| 49 | + ty_dim, # pylint: disable=unused-argument |
| 50 | + const_name: str, |
26 | 51 | ):
|
27 |
| - """Generates instruction for spirv i64 get_global_id(i32) call.""" |
| 52 | + """Generates instruction to get spirv index from const_name.""" |
28 | 53 | sig = types.int64(types.int32)
|
29 | 54 |
|
30 |
| - def _intrinsic_get_global_id_gen(context, builder, sig, args): |
31 |
| - [dim] = args |
32 |
| - get_global_id = _declare_function( |
33 |
| - # unsigned int - is what demangler returns from IR instruction |
34 |
| - # generated by dpcpp. However int32 is passed as an argument. |
35 |
| - # Most likely it does not matter, since argument can be from 0 to 3. |
36 |
| - # TODO: https://github.com/IntelPython/numba-dpex/issues/936 |
37 |
| - # Numba generates unnecessary checks because of the type mismatch. |
38 |
| - context, |
| 55 | + def _intrinsic_spirv_global_index_const_gen( |
| 56 | + context: SPIRVTargetContext, |
| 57 | + builder: llvmir.IRBuilder, |
| 58 | + sig, # pylint: disable=unused-argument |
| 59 | + args, |
| 60 | + ): |
| 61 | + index_const = declare_spirv_const( |
39 | 62 | builder,
|
40 |
| - "get_global_id", |
41 |
| - sig, |
42 |
| - ["unsigned int"], |
| 63 | + const_name, |
| 64 | + ) |
| 65 | + [dim] = args |
| 66 | + # TODO: llvmlite does not support gep on vector. Use this instead once |
| 67 | + # supported. |
| 68 | + # https://github.com/numba/llvmlite/issues/756 |
| 69 | + # res = builder.gep( # noqa: E800 |
| 70 | + # global_invocation_id, # noqa: E800 |
| 71 | + # [cgutils.int32_t(0), cgutils.int32_t(0)], # noqa: E800 |
| 72 | + # inbounds=True, # noqa: E800 |
| 73 | + # ) # noqa: E800 |
| 74 | + # res = builder.load(res, align=32) # noqa: E800 |
| 75 | + |
| 76 | + res = builder.extract_element( |
| 77 | + builder.load(index_const), |
| 78 | + dim, |
43 | 79 | )
|
44 |
| - res = builder.call(get_global_id, [dim]) |
| 80 | + |
45 | 81 | return context.cast(builder, res, types.uintp, types.intp)
|
46 | 82 |
|
47 |
| - return sig, _intrinsic_get_global_id_gen |
| 83 | + return sig, _intrinsic_spirv_global_index_const_gen |
48 | 84 |
|
49 | 85 |
|
50 | 86 | @intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
|
51 |
| -def _intrinsic_get_group( |
52 |
| - ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument |
| 87 | +def _intrinsic_spirv_global_invocation_id( |
| 88 | + ty_context, ty_dim # pylint: disable=unused-argument |
53 | 89 | ):
|
54 |
| - """Generates group with a dimension of nd_item.""" |
| 90 | + """Generates instruction to get index from BuiltInGlobalInvocationId.""" |
| 91 | + return _intrinsic_spirv_global_index_const( |
| 92 | + ty_context, ty_dim, "BuiltInGlobalInvocationId" |
| 93 | + ) |
55 | 94 |
|
56 |
| - if not isinstance(ty_nd_item, NdItemType): |
57 |
| - raise TypingError( |
58 |
| - f"Expected an NdItemType value, but encountered {ty_nd_item}" |
59 |
| - ) |
60 | 95 |
|
61 |
| - ty_group = GroupType(ty_nd_item.ndim) |
62 |
| - sig = ty_group(ty_nd_item) |
| 96 | +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
| 97 | +def _intrinsic_spirv_local_invocation_id( |
| 98 | + ty_context, ty_dim # pylint: disable=unused-argument |
| 99 | +): |
| 100 | + """Generates instruction to get index from BuiltInLocalInvocationId.""" |
| 101 | + return _intrinsic_spirv_global_index_const( |
| 102 | + ty_context, ty_dim, "BuiltInLocalInvocationId" |
| 103 | + ) |
63 | 104 |
|
64 |
| - # pylint: disable=unused-argument |
65 |
| - def _intrinsic_get_group_gen(context, builder, sig, args): |
66 |
| - group_struct = cgutils.create_struct_proxy(ty_group)(context, builder) |
67 |
| - # pylint: disable=protected-access |
68 |
| - return group_struct._getvalue() |
69 | 105 |
|
70 |
| - return sig, _intrinsic_get_group_gen |
| 106 | +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
| 107 | +def _intrinsic_spirv_global_size( |
| 108 | + ty_context, ty_dim # pylint: disable=unused-argument |
| 109 | +): |
| 110 | + """Generates instruction to get index from BuiltInGlobalSize.""" |
| 111 | + return _intrinsic_spirv_global_index_const( |
| 112 | + ty_context, ty_dim, "BuiltInGlobalSize" |
| 113 | + ) |
71 | 114 |
|
72 | 115 |
|
73 |
| -@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME) |
74 |
| -def ol_item_get_id(item, dim): |
75 |
| - """SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`. |
| 116 | +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
| 117 | +def _intrinsic_spirv_workgroup_size( |
| 118 | + ty_context, ty_dim # pylint: disable=unused-argument |
| 119 | +): |
| 120 | + """Generates instruction to get index from BuiltInWorkgroupSize.""" |
| 121 | + return _intrinsic_spirv_global_index_const( |
| 122 | + ty_context, ty_dim, "BuiltInWorkgroupSize" |
| 123 | + ) |
76 | 124 |
|
77 |
| - Generates the same LLVM IR instruction as dpcpp for the |
78 |
| - `sycl::item::get_id` function. |
79 | 125 |
|
80 |
| - Raises: |
81 |
| - TypingError: When argument is not an integer. |
82 |
| - """ |
83 |
| - if not isinstance(item, ItemType): |
84 |
| - raise TypingError( |
85 |
| - "Expected an item should to be an Item value, but " |
86 |
| - f"encountered {type(item)}" |
87 |
| - ) |
| 126 | +def generate_index_overload(_type, _intrinsic): |
| 127 | + """Generates overload for the index method that generates specific IR from |
| 128 | + provided intrinsic.""" |
88 | 129 |
|
89 |
| - if not isinstance(dim, types.Integer): |
90 |
| - raise TypingError( |
91 |
| - "Expected an Item's dim should to be an Integer value, but " |
92 |
| - f"encountered {type(dim)}" |
93 |
| - ) |
| 130 | + def ol_item_gen_index(item, dim): |
| 131 | + """SPIR-V overload for :meth:`numba_dpex.kernel_api.<_type>.<method>`. |
94 | 132 |
|
95 |
| - # pylint: disable=unused-argument |
96 |
| - def ol_item_get_id_impl(item, dim): |
97 |
| - # pylint: disable=no-value-for-parameter |
98 |
| - return _intrinsic_get_global_id(dim) |
| 133 | + Generates the same LLVM IR instruction as dpcpp for the |
| 134 | + `sycl::<type>::<method>` function. |
99 | 135 |
|
100 |
| - return ol_item_get_id_impl |
| 136 | + Raises: |
| 137 | + TypingError: When argument is not an integer. |
| 138 | + """ |
| 139 | + if not isinstance(item, _type): |
| 140 | + raise TypingError( |
| 141 | + f"Expected an item should to be an {_type} value, but " |
| 142 | + f"encountered {type(item)}" |
| 143 | + ) |
101 | 144 |
|
| 145 | + if not isinstance(dim, types.Integer): |
| 146 | + raise TypingError( |
| 147 | + f"Expected an {_type}'s dim should to be an Integer value, but " |
| 148 | + f"encountered {type(dim)}" |
| 149 | + ) |
102 | 150 |
|
103 |
| -@overload_method( |
104 |
| - NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME |
105 |
| -) |
106 |
| -def ol_nd_item_get_global_id(nd_item, dim): |
107 |
| - """SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`. |
| 151 | + # pylint: disable=unused-argument |
| 152 | + def ol_item_get_index_impl(item, dim): |
| 153 | + # TODO: call in reverse index once index reversing is removed from |
| 154 | + # kernel submission |
| 155 | + # pylint: disable=no-value-for-parameter |
| 156 | + return _intrinsic(dim) |
108 | 157 |
|
109 |
| - Generates the same LLVM IR instruction as dpcpp for the |
110 |
| - `sycl::nd_item::get_global_id` function. |
| 158 | + return ol_item_get_index_impl |
111 | 159 |
|
112 |
| - Raises: |
113 |
| - TypingError: When argument is not an integer. |
114 |
| - """ |
115 |
| - if not isinstance(nd_item, NdItemType): |
116 |
| - # since it is a method overload, this error should not be reached |
117 |
| - raise TypingError( |
118 |
| - "Expected a nd_item should to be a NdItem value, but " |
119 |
| - f"encountered {type(nd_item)}" |
120 |
| - ) |
| 160 | + return ol_item_gen_index |
| 161 | + |
| 162 | + |
| 163 | +_index_const_overload_methods = [ |
| 164 | + (ItemType, "get_id", _intrinsic_spirv_global_invocation_id), |
| 165 | + (ItemType, "get_range", _intrinsic_spirv_global_size), |
| 166 | + (NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id), |
| 167 | + (NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id), |
| 168 | + (NdItemType, "get_global_range", _intrinsic_spirv_global_size), |
| 169 | + (NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size), |
| 170 | +] |
| 171 | + |
| 172 | +for index_overload in _index_const_overload_methods: |
| 173 | + _type, method, _intrinsic = index_overload |
| 174 | + |
| 175 | + ol_index_func = generate_index_overload(_type, _intrinsic) |
121 | 176 |
|
122 |
| - if not isinstance(dim, types.Integer): |
| 177 | + overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)( |
| 178 | + ol_index_func |
| 179 | + ) |
| 180 | + |
| 181 | + |
| 182 | +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
| 183 | +def _intrinsic_get_group( |
| 184 | + ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument |
| 185 | +): |
| 186 | + """Generates group with a dimension of nd_item.""" |
| 187 | + |
| 188 | + if not isinstance(ty_nd_item, NdItemType): |
123 | 189 | raise TypingError(
|
124 |
| - "Expected a NdItem's dim should to be an Integer value, but " |
125 |
| - f"encountered {type(dim)}" |
| 190 | + f"Expected an NdItemType value, but encountered {ty_nd_item}" |
126 | 191 | )
|
127 | 192 |
|
| 193 | + ty_group = GroupType(ty_nd_item.ndim) |
| 194 | + sig = ty_group(ty_nd_item) |
| 195 | + |
128 | 196 | # pylint: disable=unused-argument
|
129 |
| - def ol_nd_item_get_global_id_impl(nd_item, dim): |
130 |
| - # pylint: disable=no-value-for-parameter |
131 |
| - return _intrinsic_get_global_id(dim) |
| 197 | + def _intrinsic_get_group_gen(context, builder, sig, args): |
| 198 | + group_struct = cgutils.create_struct_proxy(ty_group)(context, builder) |
| 199 | + # pylint: disable=protected-access |
| 200 | + return group_struct._getvalue() |
132 | 201 |
|
133 |
| - return ol_nd_item_get_global_id_impl |
| 202 | + return sig, _intrinsic_get_group_gen |
134 | 203 |
|
135 | 204 |
|
136 | 205 | @overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
|
|
0 commit comments