@@ -231,27 +231,52 @@ def _intrinsic_load(
231
231
232
232
def _intrinsic_load_gen (context , builder , sig , args ):
233
233
atomic_ref_ty = sig .args [0 ]
234
- fn = get_or_insert_atomic_load_fn (
235
- context , builder .module , atomic_ref_ty
236
- )
237
234
238
- spirv_memory_semantics_mask = get_memory_semantics_mask (
239
- atomic_ref_ty .memory_order
235
+ atomic_ref_ptr = builder .extract_value (
236
+ args [0 ],
237
+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
238
+ "ref"
239
+ ),
240
240
)
241
- spirv_scope = get_scope (atomic_ref_ty .memory_scope )
241
+ if sig .args [0 ].dtype == types .float32 :
242
+ atomic_ref_ptr = builder .bitcast (
243
+ atomic_ref_ptr ,
244
+ llvmir .PointerType (
245
+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
246
+ ),
247
+ )
248
+ elif sig .args [0 ].dtype == types .float64 :
249
+ atomic_ref_ptr = builder .bitcast (
250
+ atomic_ref_ptr ,
251
+ llvmir .PointerType (
252
+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
253
+ ),
254
+ )
242
255
243
256
fn_args = [
244
- builder .extract_value (
245
- args [0 ],
246
- context .data_model_manager .lookup (
247
- atomic_ref_ty
248
- ).get_field_position ("ref" ),
257
+ atomic_ref_ptr ,
258
+ context .get_constant (
259
+ types .int32 , get_scope (atomic_ref_ty .memory_scope )
260
+ ),
261
+ context .get_constant (
262
+ types .int32 ,
263
+ get_memory_semantics_mask (atomic_ref_ty .memory_order ),
249
264
),
250
- context .get_constant (types .int32 , spirv_scope ),
251
- context .get_constant (types .int32 , spirv_memory_semantics_mask ),
252
265
]
253
266
254
- return builder .call (fn , fn_args )
267
+ ret_val = builder .call (
268
+ get_or_insert_atomic_load_fn (
269
+ context , builder .module , atomic_ref_ty
270
+ ),
271
+ fn_args ,
272
+ )
273
+
274
+ if sig .args [0 ].dtype == types .float32 :
275
+ ret_val = builder .bitcast (ret_val , llvmir .FloatType ())
276
+ elif sig .args [0 ].dtype == types .float64 :
277
+ ret_val = builder .bitcast (ret_val , llvmir .DoubleType ())
278
+
279
+ return ret_val
255
280
256
281
return sig , _intrinsic_load_gen
257
282
@@ -264,28 +289,49 @@ def _intrinsic_store(
264
289
265
290
def _intrinsic_store_gen (context , builder , sig , args ):
266
291
atomic_ref_ty = sig .args [0 ]
267
- atomic_store_fn = get_or_insert_spv_atomic_store_fn (
268
- context , builder .module , atomic_ref_ty
292
+
293
+ store_arg = args [1 ]
294
+ atomic_ref_ptr = builder .extract_value (
295
+ args [0 ],
296
+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
297
+ "ref"
298
+ ),
269
299
)
300
+ if sig .args [0 ].dtype == types .float32 :
301
+ atomic_ref_ptr = builder .bitcast (
302
+ atomic_ref_ptr ,
303
+ llvmir .PointerType (
304
+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
305
+ ),
306
+ )
307
+ store_arg = builder .bitcast (store_arg , llvmir .IntType (32 ))
308
+ elif sig .args [0 ].dtype == types .float64 :
309
+ atomic_ref_ptr = builder .bitcast (
310
+ atomic_ref_ptr ,
311
+ llvmir .PointerType (
312
+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
313
+ ),
314
+ )
315
+ store_arg = builder .bitcast (store_arg , llvmir .IntType (64 ))
270
316
271
317
atomic_store_fn_args = [
272
- builder .extract_value (
273
- args [0 ],
274
- context .data_model_manager .lookup (
275
- atomic_ref_ty
276
- ).get_field_position ("ref" ),
277
- ),
318
+ atomic_ref_ptr ,
278
319
context .get_constant (
279
320
types .int32 , get_scope (atomic_ref_ty .memory_scope )
280
321
),
281
322
context .get_constant (
282
323
types .int32 ,
283
324
get_memory_semantics_mask (atomic_ref_ty .memory_order ),
284
325
),
285
- args [ 1 ] ,
326
+ store_arg ,
286
327
]
287
328
288
- builder .call (atomic_store_fn , atomic_store_fn_args )
329
+ builder .call (
330
+ get_or_insert_spv_atomic_store_fn (
331
+ context , builder .module , atomic_ref_ty
332
+ ),
333
+ atomic_store_fn_args ,
334
+ )
289
335
290
336
return sig , _intrinsic_store_gen
291
337
@@ -298,28 +344,56 @@ def _intrinsic_exchange(
298
344
299
345
def _intrinsic_exchange_gen (context , builder , sig , args ):
300
346
atomic_ref_ty = sig .args [0 ]
301
- atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn (
302
- context , builder .module , atomic_ref_ty
347
+
348
+ exchange_arg = args [1 ]
349
+ atomic_ref_ptr = builder .extract_value (
350
+ args [0 ],
351
+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
352
+ "ref"
353
+ ),
303
354
)
355
+ if sig .args [0 ].dtype == types .float32 :
356
+ atomic_ref_ptr = builder .bitcast (
357
+ atomic_ref_ptr ,
358
+ llvmir .PointerType (
359
+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
360
+ ),
361
+ )
362
+ exchange_arg = builder .bitcast (exchange_arg , llvmir .IntType (32 ))
363
+ elif sig .args [0 ].dtype == types .float64 :
364
+ atomic_ref_ptr = builder .bitcast (
365
+ atomic_ref_ptr ,
366
+ llvmir .PointerType (
367
+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
368
+ ),
369
+ )
370
+ exchange_arg = builder .bitcast (exchange_arg , llvmir .IntType (64 ))
304
371
305
372
atomic_exchange_fn_args = [
306
- builder .extract_value (
307
- args [0 ],
308
- context .data_model_manager .lookup (
309
- atomic_ref_ty
310
- ).get_field_position ("ref" ),
311
- ),
373
+ atomic_ref_ptr ,
312
374
context .get_constant (
313
375
types .int32 , get_scope (atomic_ref_ty .memory_scope )
314
376
),
315
377
context .get_constant (
316
378
types .int32 ,
317
379
get_memory_semantics_mask (atomic_ref_ty .memory_order ),
318
380
),
319
- args [ 1 ] ,
381
+ exchange_arg ,
320
382
]
321
383
322
- return builder .call (atomic_exchange_fn , atomic_exchange_fn_args )
384
+ ret_val = builder .call (
385
+ get_or_insert_spv_atomic_exchange_fn (
386
+ context , builder .module , atomic_ref_ty
387
+ ),
388
+ atomic_exchange_fn_args ,
389
+ )
390
+
391
+ if sig .args [0 ].dtype == types .float32 :
392
+ ret_val = builder .bitcast (ret_val , llvmir .FloatType ())
393
+ elif sig .args [0 ].dtype == types .float64 :
394
+ ret_val = builder .bitcast (ret_val , llvmir .DoubleType ())
395
+
396
+ return ret_val
323
397
324
398
return sig , _intrinsic_exchange_gen
325
399
0 commit comments