@@ -231,27 +231,52 @@ def _intrinsic_load(
231231
232232 def _intrinsic_load_gen (context , builder , sig , args ):
233233 atomic_ref_ty = sig .args [0 ]
234- fn = get_or_insert_atomic_load_fn (
235- context , builder .module , atomic_ref_ty
236- )
237234
238- spirv_memory_semantics_mask = get_memory_semantics_mask (
239- atomic_ref_ty .memory_order
235+ atomic_ref_ptr = builder .extract_value (
236+ args [0 ],
237+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
238+ "ref"
239+ ),
240240 )
241- spirv_scope = get_scope (atomic_ref_ty .memory_scope )
241+ if sig .args [0 ].dtype == types .float32 :
242+ atomic_ref_ptr = builder .bitcast (
243+ atomic_ref_ptr ,
244+ llvmir .PointerType (
245+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
246+ ),
247+ )
248+ elif sig .args [0 ].dtype == types .float64 :
249+ atomic_ref_ptr = builder .bitcast (
250+ atomic_ref_ptr ,
251+ llvmir .PointerType (
252+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
253+ ),
254+ )
242255
243256 fn_args = [
244- builder .extract_value (
245- args [0 ],
246- context .data_model_manager .lookup (
247- atomic_ref_ty
248- ).get_field_position ("ref" ),
257+ atomic_ref_ptr ,
258+ context .get_constant (
259+ types .int32 , get_scope (atomic_ref_ty .memory_scope )
260+ ),
261+ context .get_constant (
262+ types .int32 ,
263+ get_memory_semantics_mask (atomic_ref_ty .memory_order ),
249264 ),
250- context .get_constant (types .int32 , spirv_scope ),
251- context .get_constant (types .int32 , spirv_memory_semantics_mask ),
252265 ]
253266
254- return builder .call (fn , fn_args )
267+ ret_val = builder .call (
268+ get_or_insert_atomic_load_fn (
269+ context , builder .module , atomic_ref_ty
270+ ),
271+ fn_args ,
272+ )
273+
274+ if sig .args [0 ].dtype == types .float32 :
275+ ret_val = builder .bitcast (ret_val , llvmir .FloatType ())
276+ elif sig .args [0 ].dtype == types .float64 :
277+ ret_val = builder .bitcast (ret_val , llvmir .DoubleType ())
278+
279+ return ret_val
255280
256281 return sig , _intrinsic_load_gen
257282
@@ -264,28 +289,49 @@ def _intrinsic_store(
264289
265290 def _intrinsic_store_gen (context , builder , sig , args ):
266291 atomic_ref_ty = sig .args [0 ]
267- atomic_store_fn = get_or_insert_spv_atomic_store_fn (
268- context , builder .module , atomic_ref_ty
292+
293+ store_arg = args [1 ]
294+ atomic_ref_ptr = builder .extract_value (
295+ args [0 ],
296+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
297+ "ref"
298+ ),
269299 )
300+ if sig .args [0 ].dtype == types .float32 :
301+ atomic_ref_ptr = builder .bitcast (
302+ atomic_ref_ptr ,
303+ llvmir .PointerType (
304+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
305+ ),
306+ )
307+ store_arg = builder .bitcast (store_arg , llvmir .IntType (32 ))
308+ elif sig .args [0 ].dtype == types .float64 :
309+ atomic_ref_ptr = builder .bitcast (
310+ atomic_ref_ptr ,
311+ llvmir .PointerType (
312+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
313+ ),
314+ )
315+ store_arg = builder .bitcast (store_arg , llvmir .IntType (64 ))
270316
271317 atomic_store_fn_args = [
272- builder .extract_value (
273- args [0 ],
274- context .data_model_manager .lookup (
275- atomic_ref_ty
276- ).get_field_position ("ref" ),
277- ),
318+ atomic_ref_ptr ,
278319 context .get_constant (
279320 types .int32 , get_scope (atomic_ref_ty .memory_scope )
280321 ),
281322 context .get_constant (
282323 types .int32 ,
283324 get_memory_semantics_mask (atomic_ref_ty .memory_order ),
284325 ),
285- args [ 1 ] ,
326+ store_arg ,
286327 ]
287328
288- builder .call (atomic_store_fn , atomic_store_fn_args )
329+ builder .call (
330+ get_or_insert_spv_atomic_store_fn (
331+ context , builder .module , atomic_ref_ty
332+ ),
333+ atomic_store_fn_args ,
334+ )
289335
290336 return sig , _intrinsic_store_gen
291337
@@ -298,28 +344,56 @@ def _intrinsic_exchange(
298344
299345 def _intrinsic_exchange_gen (context , builder , sig , args ):
300346 atomic_ref_ty = sig .args [0 ]
301- atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn (
302- context , builder .module , atomic_ref_ty
347+
348+ exchange_arg = args [1 ]
349+ atomic_ref_ptr = builder .extract_value (
350+ args [0 ],
351+ context .data_model_manager .lookup (atomic_ref_ty ).get_field_position (
352+ "ref"
353+ ),
303354 )
355+ if sig .args [0 ].dtype == types .float32 :
356+ atomic_ref_ptr = builder .bitcast (
357+ atomic_ref_ptr ,
358+ llvmir .PointerType (
359+ llvmir .IntType (32 ), addrspace = sig .args [0 ].address_space
360+ ),
361+ )
362+ exchange_arg = builder .bitcast (exchange_arg , llvmir .IntType (32 ))
363+ elif sig .args [0 ].dtype == types .float64 :
364+ atomic_ref_ptr = builder .bitcast (
365+ atomic_ref_ptr ,
366+ llvmir .PointerType (
367+ llvmir .IntType (64 ), addrspace = sig .args [0 ].address_space
368+ ),
369+ )
370+ exchange_arg = builder .bitcast (exchange_arg , llvmir .IntType (64 ))
304371
305372 atomic_exchange_fn_args = [
306- builder .extract_value (
307- args [0 ],
308- context .data_model_manager .lookup (
309- atomic_ref_ty
310- ).get_field_position ("ref" ),
311- ),
373+ atomic_ref_ptr ,
312374 context .get_constant (
313375 types .int32 , get_scope (atomic_ref_ty .memory_scope )
314376 ),
315377 context .get_constant (
316378 types .int32 ,
317379 get_memory_semantics_mask (atomic_ref_ty .memory_order ),
318380 ),
319- args [ 1 ] ,
381+ exchange_arg ,
320382 ]
321383
322- return builder .call (atomic_exchange_fn , atomic_exchange_fn_args )
384+ ret_val = builder .call (
385+ get_or_insert_spv_atomic_exchange_fn (
386+ context , builder .module , atomic_ref_ty
387+ ),
388+ atomic_exchange_fn_args ,
389+ )
390+
391+ if sig .args [0 ].dtype == types .float32 :
392+ ret_val = builder .bitcast (ret_val , llvmir .FloatType ())
393+ elif sig .args [0 ].dtype == types .float64 :
394+ ret_val = builder .bitcast (ret_val , llvmir .DoubleType ())
395+
396+ return ret_val
323397
324398 return sig , _intrinsic_exchange_gen
325399
0 commit comments