Skip to content

Commit d4118df

Browse files
authored
Merge pull request #1323 from IntelPython/feature/index_function_overloads
Feature/index function overloads
2 parents 8a633c4 + 31e1205 commit d4118df

File tree

3 files changed

+290
-82
lines changed

3 files changed

+290
-82
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 149 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,131 +6,200 @@
66
Implements the SPIR-V overloads for the kernel_api.items class methods.
77
"""
88

9+
import llvmlite.ir as llvmir
910
from numba.core import cgutils, types
1011
from numba.core.errors import TypingError
1112
from numba.extending import intrinsic, overload_method
1213

14+
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
1315
from numba_dpex.experimental.core.types.kernel_api.items import (
1416
GroupType,
1517
ItemType,
1618
NdItemType,
1719
)
18-
from numba_dpex.ocl._declare_function import _declare_function
1920

2021
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
2122

2223

23-
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
24-
def _intrinsic_get_global_id(
25-
ty_context, ty_dim # pylint: disable=unused-argument
24+
def spirv_name(name: str):
25+
"""Converts name to spirv name by adding __spirv_ prefix."""
26+
return "__spirv_" + name
27+
28+
29+
def declare_spirv_const(
30+
builder: llvmir.IRBuilder,
31+
name: str,
32+
):
33+
"""Declares global external spirv constant"""
34+
data = cgutils.add_global_variable(
35+
builder.module,
36+
llvmir.VectorType(llvmir.IntType(64), 3),
37+
spirv_name(name),
38+
addrspace=1,
39+
)
40+
data.linkage = "external"
41+
data.global_constant = True
42+
data.align = 32
43+
data.storage_class = "dso_local local_unnamed_addr"
44+
return data
45+
46+
47+
def _intrinsic_spirv_global_index_const(
48+
ty_context, # pylint: disable=unused-argument
49+
ty_dim, # pylint: disable=unused-argument
50+
const_name: str,
2651
):
27-
"""Generates instruction for spirv i64 get_global_id(i32) call."""
52+
"""Generates instruction to get spirv index from const_name."""
2853
sig = types.int64(types.int32)
2954

30-
def _intrinsic_get_global_id_gen(context, builder, sig, args):
31-
[dim] = args
32-
get_global_id = _declare_function(
33-
# unsigned int - is what demangler returns from IR instruction
34-
# generated by dpcpp. However int32 is passed as an argument.
35-
# Most likely it does not matter, since argument can be from 0 to 3.
36-
# TODO: https://github.com/IntelPython/numba-dpex/issues/936
37-
# Numba generates unnecessary checks because of the type mismatch.
38-
context,
55+
def _intrinsic_spirv_global_index_const_gen(
56+
context: SPIRVTargetContext,
57+
builder: llvmir.IRBuilder,
58+
sig, # pylint: disable=unused-argument
59+
args,
60+
):
61+
index_const = declare_spirv_const(
3962
builder,
40-
"get_global_id",
41-
sig,
42-
["unsigned int"],
63+
const_name,
64+
)
65+
[dim] = args
66+
# TODO: llvmlite does not support gep on vector. Use this instead once
67+
# supported.
68+
# https://github.com/numba/llvmlite/issues/756
69+
# res = builder.gep( # noqa: E800
70+
# global_invocation_id, # noqa: E800
71+
# [cgutils.int32_t(0), cgutils.int32_t(0)], # noqa: E800
72+
# inbounds=True, # noqa: E800
73+
# ) # noqa: E800
74+
# res = builder.load(res, align=32) # noqa: E800
75+
76+
res = builder.extract_element(
77+
builder.load(index_const),
78+
dim,
4379
)
44-
res = builder.call(get_global_id, [dim])
80+
4581
return context.cast(builder, res, types.uintp, types.intp)
4682

47-
return sig, _intrinsic_get_global_id_gen
83+
return sig, _intrinsic_spirv_global_index_const_gen
4884

4985

5086
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
51-
def _intrinsic_get_group(
52-
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
87+
def _intrinsic_spirv_global_invocation_id(
88+
ty_context, ty_dim # pylint: disable=unused-argument
5389
):
54-
"""Generates group with a dimension of nd_item."""
90+
"""Generates instruction to get index from BuiltInGlobalInvocationId."""
91+
return _intrinsic_spirv_global_index_const(
92+
ty_context, ty_dim, "BuiltInGlobalInvocationId"
93+
)
5594

56-
if not isinstance(ty_nd_item, NdItemType):
57-
raise TypingError(
58-
f"Expected an NdItemType value, but encountered {ty_nd_item}"
59-
)
6095

61-
ty_group = GroupType(ty_nd_item.ndim)
62-
sig = ty_group(ty_nd_item)
96+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
97+
def _intrinsic_spirv_local_invocation_id(
98+
ty_context, ty_dim # pylint: disable=unused-argument
99+
):
100+
"""Generates instruction to get index from BuiltInLocalInvocationId."""
101+
return _intrinsic_spirv_global_index_const(
102+
ty_context, ty_dim, "BuiltInLocalInvocationId"
103+
)
63104

64-
# pylint: disable=unused-argument
65-
def _intrinsic_get_group_gen(context, builder, sig, args):
66-
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
67-
# pylint: disable=protected-access
68-
return group_struct._getvalue()
69105

70-
return sig, _intrinsic_get_group_gen
106+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
107+
def _intrinsic_spirv_global_size(
108+
ty_context, ty_dim # pylint: disable=unused-argument
109+
):
110+
"""Generates instruction to get index from BuiltInGlobalSize."""
111+
return _intrinsic_spirv_global_index_const(
112+
ty_context, ty_dim, "BuiltInGlobalSize"
113+
)
71114

72115

73-
@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
74-
def ol_item_get_id(item, dim):
75-
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.
116+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
117+
def _intrinsic_spirv_workgroup_size(
118+
ty_context, ty_dim # pylint: disable=unused-argument
119+
):
120+
"""Generates instruction to get index from BuiltInWorkgroupSize."""
121+
return _intrinsic_spirv_global_index_const(
122+
ty_context, ty_dim, "BuiltInWorkgroupSize"
123+
)
76124

77-
Generates the same LLVM IR instruction as dpcpp for the
78-
`sycl::item::get_id` function.
79125

80-
Raises:
81-
TypingError: When argument is not an integer.
82-
"""
83-
if not isinstance(item, ItemType):
84-
raise TypingError(
85-
"Expected an item should to be an Item value, but "
86-
f"encountered {type(item)}"
87-
)
126+
def generate_index_overload(_type, _intrinsic):
127+
"""Generates overload for the index method that generates specific IR from
128+
provided intrinsic."""
88129

89-
if not isinstance(dim, types.Integer):
90-
raise TypingError(
91-
"Expected an Item's dim should to be an Integer value, but "
92-
f"encountered {type(dim)}"
93-
)
130+
def ol_item_gen_index(item, dim):
131+
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.<_type>.<method>`.
94132
95-
# pylint: disable=unused-argument
96-
def ol_item_get_id_impl(item, dim):
97-
# pylint: disable=no-value-for-parameter
98-
return _intrinsic_get_global_id(dim)
133+
Generates the same LLVM IR instruction as dpcpp for the
134+
`sycl::<type>::<method>` function.
99135
100-
return ol_item_get_id_impl
136+
Raises:
137+
TypingError: When argument is not an integer.
138+
"""
139+
if not isinstance(item, _type):
140+
raise TypingError(
141+
f"Expected an item should to be an {_type} value, but "
142+
f"encountered {type(item)}"
143+
)
101144

145+
if not isinstance(dim, types.Integer):
146+
raise TypingError(
147+
f"Expected an {_type}'s dim should to be an Integer value, but "
148+
f"encountered {type(dim)}"
149+
)
102150

103-
@overload_method(
104-
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
105-
)
106-
def ol_nd_item_get_global_id(nd_item, dim):
107-
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
151+
# pylint: disable=unused-argument
152+
def ol_item_get_index_impl(item, dim):
153+
# TODO: call in reverse index once index reversing is removed from
154+
# kernel submission
155+
# pylint: disable=no-value-for-parameter
156+
return _intrinsic(dim)
108157

109-
Generates the same LLVM IR instruction as dpcpp for the
110-
`sycl::nd_item::get_global_id` function.
158+
return ol_item_get_index_impl
111159

112-
Raises:
113-
TypingError: When argument is not an integer.
114-
"""
115-
if not isinstance(nd_item, NdItemType):
116-
# since it is a method overload, this error should not be reached
117-
raise TypingError(
118-
"Expected a nd_item should to be a NdItem value, but "
119-
f"encountered {type(nd_item)}"
120-
)
160+
return ol_item_gen_index
161+
162+
163+
_index_const_overload_methods = [
164+
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
165+
(ItemType, "get_range", _intrinsic_spirv_global_size),
166+
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
167+
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
168+
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
169+
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
170+
]
171+
172+
for index_overload in _index_const_overload_methods:
173+
_type, method, _intrinsic = index_overload
174+
175+
ol_index_func = generate_index_overload(_type, _intrinsic)
121176

122-
if not isinstance(dim, types.Integer):
177+
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
178+
ol_index_func
179+
)
180+
181+
182+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
183+
def _intrinsic_get_group(
184+
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
185+
):
186+
"""Generates group with a dimension of nd_item."""
187+
188+
if not isinstance(ty_nd_item, NdItemType):
123189
raise TypingError(
124-
"Expected a NdItem's dim should to be an Integer value, but "
125-
f"encountered {type(dim)}"
190+
f"Expected an NdItemType value, but encountered {ty_nd_item}"
126191
)
127192

193+
ty_group = GroupType(ty_nd_item.ndim)
194+
sig = ty_group(ty_nd_item)
195+
128196
# pylint: disable=unused-argument
129-
def ol_nd_item_get_global_id_impl(nd_item, dim):
130-
# pylint: disable=no-value-for-parameter
131-
return _intrinsic_get_global_id(dim)
197+
def _intrinsic_get_group_gen(context, builder, sig, args):
198+
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
199+
# pylint: disable=protected-access
200+
return group_struct._getvalue()
132201

133-
return ol_nd_item_get_global_id_impl
202+
return sig, _intrinsic_get_group_gen
134203

135204

136205
@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def get_id(self, idx):
6161
"""
6262
return self._index[idx]
6363

64+
def get_range(self, idx):
65+
"""Get the range size for a specific dimension.
66+
67+
Returns:
68+
int: The size
69+
"""
70+
return self._extent[idx]
71+
6472
@property
6573
def ndim(self) -> int:
6674
"""Returns the rank of a Item object.
@@ -117,6 +125,22 @@ def get_local_linear_id(self):
117125
"""
118126
return self._local_item.get_linear_id()
119127

128+
def get_global_range(self, idx):
129+
"""Get the global range size for a specific dimension.
130+
131+
Returns:
132+
int: The size
133+
"""
134+
return self._global_item.get_range(idx)
135+
136+
def get_local_range(self, idx):
137+
"""Get the local range size for a specific dimension.
138+
139+
Returns:
140+
int: The size
141+
"""
142+
return self._local_item.get_range(idx)
143+
120144
def get_group(self):
121145
"""Returns the group.
122146

0 commit comments

Comments
 (0)