Skip to content

Commit 674ef5b

Browse files
committed
Wrap const indexing overloads into function
1 parent 7f1880d commit 674ef5b

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,31 @@ def ol_item_get_index_impl(item, dim):
180180
return ol_item_gen_index
181181

182182

183-
_index_const_overload_methods = [
184-
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
185-
(ItemType, "get_range", _intrinsic_spirv_global_size),
186-
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
187-
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
188-
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
189-
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
190-
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
191-
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
192-
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
193-
]
194-
195-
for index_overload in _index_const_overload_methods:
196-
_type, method, _intrinsic = index_overload
197-
198-
ol_index_func = generate_index_overload(_type, _intrinsic)
199-
200-
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
201-
ol_index_func
202-
)
183+
def register_index_const_methods():
184+
"""Register indexing related methods that can be defined as spirv const."""
185+
_index_const_overload_methods = [
186+
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
187+
(ItemType, "get_range", _intrinsic_spirv_global_size),
188+
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
189+
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
190+
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
191+
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
192+
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
193+
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
194+
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
195+
]
196+
197+
for index_overload in _index_const_overload_methods:
198+
_type, method, _intrinsic = index_overload
199+
200+
ol_index_func = generate_index_overload(_type, _intrinsic)
201+
202+
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
203+
ol_index_func
204+
)
205+
206+
207+
register_index_const_methods()
203208

204209

205210
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)

0 commit comments

Comments
 (0)