24
24
get_memory_semantics_mask ,
25
25
get_scope ,
26
26
)
27
+ from .spv_fn_generator import (
28
+ get_or_insert_atomic_load_fn ,
29
+ get_or_insert_spv_atomic_exchange_fn ,
30
+ get_or_insert_spv_atomic_store_fn ,
31
+ )
27
32
28
33
29
34
def _parse_enum_or_int_literal_ (literal_int ) -> int :
@@ -217,44 +222,22 @@ def _intrinsic_load(
217
222
218
223
def _intrinsic_load_gen (context , builder , sig , args ):
219
224
atomic_ref_ty = sig .args [0 ]
220
- atomic_ref_dtype = atomic_ref_ty .dtype
221
- retty = context .get_value_type (atomic_ref_dtype )
222
-
223
- data_attr_pos = context .data_model_manager .lookup (
224
- atomic_ref_ty
225
- ).get_field_position ("ref" )
226
-
227
- ptr_type = retty .as_pointer ()
228
- ptr_type .addrspace = atomic_ref_ty .address_space
229
-
230
- spirv_fn_arg_types = [
231
- ptr_type ,
232
- llvmir .IntType (32 ),
233
- llvmir .IntType (32 ),
234
- ]
235
-
236
- mangled_fn_name = ext_itanium_mangler .mangle_ext (
237
- "__spirv_AtomicLoad" ,
238
- [
239
- types .CPointer (atomic_ref_dtype , addrspace = ptr_type .addrspace ),
240
- "__spv.Scope.Flag" ,
241
- "__spv.MemorySemanticsMask.Flag" ,
242
- ],
225
+ fn = get_or_insert_atomic_load_fn (
226
+ context , builder .module , atomic_ref_ty
243
227
)
244
228
245
- fn = cgutils .get_or_insert_function (
246
- builder .module ,
247
- llvmir .FunctionType (retty , spirv_fn_arg_types ),
248
- mangled_fn_name ,
249
- )
250
- fn .calling_convention = CC_SPIR_FUNC
251
229
spirv_memory_semantics_mask = get_memory_semantics_mask (
252
230
atomic_ref_ty .memory_order
253
231
)
254
232
spirv_scope = get_scope (atomic_ref_ty .memory_scope )
255
233
256
234
fn_args = [
257
- builder .extract_value (args [0 ], data_attr_pos ),
235
+ builder .extract_value (
236
+ args [0 ],
237
+ context .data_model_manager .lookup (
238
+ atomic_ref_ty
239
+ ).get_field_position ("ref" ),
240
+ ),
258
241
context .get_constant (types .int32 , spirv_scope ),
259
242
context .get_constant (types .int32 , spirv_memory_semantics_mask ),
260
243
]
@@ -264,76 +247,37 @@ def _intrinsic_load_gen(context, builder, sig, args):
264
247
return sig , _intrinsic_load_gen
265
248
266
249
267
- def _store_exchange_intrisic_helper (context , builder , sig , ol_info : dict ):
268
- atomic_ref_ty = sig .args [0 ]
269
- atomic_ref_dtype = atomic_ref_ty .dtype
270
-
271
- ptr_type = context .get_value_type (atomic_ref_dtype ).as_pointer ()
272
- ptr_type .addrspace = atomic_ref_ty .address_space
273
-
274
- spirv_fn_arg_types = [
275
- ptr_type ,
276
- llvmir .IntType (32 ),
277
- llvmir .IntType (32 ),
278
- context .get_value_type (atomic_ref_dtype ),
279
- ]
280
-
281
- mangled_fn_name = ext_itanium_mangler .mangle_ext (
282
- ol_info ["name" ],
283
- [
284
- types .CPointer (atomic_ref_dtype , addrspace = ptr_type .addrspace ),
285
- "__spv.Scope.Flag" ,
286
- "__spv.MemorySemanticsMask.Flag" ,
287
- atomic_ref_dtype ,
288
- ],
289
- )
290
-
291
- fn = cgutils .get_or_insert_function (
292
- builder .module ,
293
- llvmir .FunctionType (ol_info ["retty" ], spirv_fn_arg_types ),
294
- mangled_fn_name ,
295
- )
296
- fn .calling_convention = CC_SPIR_FUNC
297
-
298
- fn_args = [
299
- builder .extract_value (
300
- ol_info ["args" ][0 ],
301
- context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
302
- "ref"
303
- ),
304
- ),
305
- context .get_constant (
306
- types .int32 , get_scope (atomic_ref_ty .memory_scope )
307
- ),
308
- context .get_constant (
309
- types .int32 , get_memory_semantics_mask (atomic_ref_ty .memory_order )
310
- ),
311
- ol_info ["args" ][1 ],
312
- ]
313
-
314
- return builder .call (fn , fn_args )
315
-
316
-
317
250
@intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
318
251
def _intrinsic_store (
319
252
ty_context , ty_atomic_ref , ty_val
320
253
): # pylint: disable=unused-argument
321
254
sig = types .void (ty_atomic_ref , ty_val )
322
255
323
256
def _intrinsic_store_gen (context , builder , sig , args ):
324
- _store_exchange_intrisic_helper (
325
- context ,
326
- builder ,
327
- sig ,
328
- # dict containing arguments, return type,
329
- # spirv fn name driven by pylint too-many-args
330
- {
331
- "args" : args ,
332
- "retty" : llvmir .VoidType (),
333
- "name" : "__spirv_AtomicStore" ,
334
- },
257
+ atomic_ref_ty = sig .args [0 ]
258
+ atomic_store_fn = get_or_insert_spv_atomic_store_fn (
259
+ context , builder .module , atomic_ref_ty
335
260
)
336
261
262
+ atomic_store_fn_args = [
263
+ builder .extract_value (
264
+ args [0 ],
265
+ context .data_model_manager .lookup (
266
+ atomic_ref_ty
267
+ ).get_field_position ("ref" ),
268
+ ),
269
+ context .get_constant (
270
+ types .int32 , get_scope (atomic_ref_ty .memory_scope )
271
+ ),
272
+ context .get_constant (
273
+ types .int32 ,
274
+ get_memory_semantics_mask (atomic_ref_ty .memory_order ),
275
+ ),
276
+ args [1 ],
277
+ ]
278
+
279
+ builder .call (atomic_store_fn , atomic_store_fn_args )
280
+
337
281
return sig , _intrinsic_store_gen
338
282
339
283
@@ -344,19 +288,30 @@ def _intrinsic_exchange(
344
288
sig = ty_atomic_ref .dtype (ty_atomic_ref , ty_val )
345
289
346
290
def _intrinsic_exchange_gen (context , builder , sig , args ):
347
- return _store_exchange_intrisic_helper (
348
- context ,
349
- builder ,
350
- sig ,
351
- # dict containing arguments, return type,
352
- # spirv fn name driven by pylint too-many-args
353
- {
354
- "args" : args ,
355
- "retty" : context .get_value_type (sig .args [0 ].dtype ),
356
- "name" : "__spirv_AtomicExchange" ,
357
- },
291
+ atomic_ref_ty = sig .args [0 ]
292
+ atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn (
293
+ context , builder .module , atomic_ref_ty
358
294
)
359
295
296
+ atomic_exchange_fn_args = [
297
+ builder .extract_value (
298
+ args [0 ],
299
+ context .data_model_manager .lookup (
300
+ atomic_ref_ty
301
+ ).get_field_position ("ref" ),
302
+ ),
303
+ context .get_constant (
304
+ types .int32 , get_scope (atomic_ref_ty .memory_scope )
305
+ ),
306
+ context .get_constant (
307
+ types .int32 ,
308
+ get_memory_semantics_mask (atomic_ref_ty .memory_order ),
309
+ ),
310
+ args [1 ],
311
+ ]
312
+
313
+ return builder .call (atomic_exchange_fn , atomic_exchange_fn_args )
314
+
360
315
return sig , _intrinsic_exchange_gen
361
316
362
317
0 commit comments