@@ -180,26 +180,31 @@ def ol_item_get_index_impl(item, dim):
180
180
return ol_item_gen_index
181
181
182
182
183
- _index_const_overload_methods = [
184
- (ItemType , "get_id" , _intrinsic_spirv_global_invocation_id ),
185
- (ItemType , "get_range" , _intrinsic_spirv_global_size ),
186
- (NdItemType , "get_global_id" , _intrinsic_spirv_global_invocation_id ),
187
- (NdItemType , "get_local_id" , _intrinsic_spirv_local_invocation_id ),
188
- (NdItemType , "get_global_range" , _intrinsic_spirv_global_size ),
189
- (NdItemType , "get_local_range" , _intrinsic_spirv_workgroup_size ),
190
- (GroupType , "get_group_id" , _intrinsic_spirv_workgroup_id ),
191
- (GroupType , "get_group_range" , _intrinsic_spirv_numworkgroups ),
192
- (GroupType , "get_local_range" , _intrinsic_spirv_workgroup_size ),
193
- ]
194
-
195
- for index_overload in _index_const_overload_methods :
196
- _type , method , _intrinsic = index_overload
197
-
198
- ol_index_func = generate_index_overload (_type , _intrinsic )
199
-
200
- overload_method (_type , method , target = DPEX_KERNEL_EXP_TARGET_NAME )(
201
- ol_index_func
202
- )
183
+ def register_index_const_methods ():
184
+ """Register indexing related methods that can be defined as spirv const."""
185
+ _index_const_overload_methods = [
186
+ (ItemType , "get_id" , _intrinsic_spirv_global_invocation_id ),
187
+ (ItemType , "get_range" , _intrinsic_spirv_global_size ),
188
+ (NdItemType , "get_global_id" , _intrinsic_spirv_global_invocation_id ),
189
+ (NdItemType , "get_local_id" , _intrinsic_spirv_local_invocation_id ),
190
+ (NdItemType , "get_global_range" , _intrinsic_spirv_global_size ),
191
+ (NdItemType , "get_local_range" , _intrinsic_spirv_workgroup_size ),
192
+ (GroupType , "get_group_id" , _intrinsic_spirv_workgroup_id ),
193
+ (GroupType , "get_group_range" , _intrinsic_spirv_numworkgroups ),
194
+ (GroupType , "get_local_range" , _intrinsic_spirv_workgroup_size ),
195
+ ]
196
+
197
+ for index_overload in _index_const_overload_methods :
198
+ _type , method , _intrinsic = index_overload
199
+
200
+ ol_index_func = generate_index_overload (_type , _intrinsic )
201
+
202
+ overload_method (_type , method , target = DPEX_KERNEL_EXP_TARGET_NAME )(
203
+ ol_index_func
204
+ )
205
+
206
+
207
+ register_index_const_methods ()
203
208
204
209
205
210
@intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
0 commit comments