|
34 | 34 | )
|
35 | 35 | from .spv_fn_generator import (
|
36 | 36 | get_or_insert_atomic_load_fn,
|
| 37 | + get_or_insert_spv_atomic_compare_exchange_fn, |
37 | 38 | get_or_insert_spv_atomic_exchange_fn,
|
38 | 39 | get_or_insert_spv_atomic_store_fn,
|
39 | 40 | )
|
@@ -323,6 +324,108 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
|
323 | 324 | return sig, _intrinsic_exchange_gen
|
324 | 325 |
|
325 | 326 |
|
| 327 | +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) |
| 328 | +def _intrinsic_compare_exchange( |
| 329 | + ty_context, # pylint: disable=unused-argument |
| 330 | + ty_atomic_ref, |
| 331 | + ty_expected_ref, |
| 332 | + ty_desired, |
| 333 | + ty_expected_idx, |
| 334 | +): |
| 335 | + sig = types.boolean( |
| 336 | + ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx |
| 337 | + ) |
| 338 | + |
| 339 | + def _intrinsic_compare_exchange_gen(context, builder, sig, args): |
| 340 | + # get pointer to expected[expected_idx] |
| 341 | + data_attr = builder.extract_value( |
| 342 | + args[1], |
| 343 | + context.data_model_manager.lookup(sig.args[1]).get_field_position( |
| 344 | + "data" |
| 345 | + ), |
| 346 | + ) |
| 347 | + with builder.goto_entry_block(): |
| 348 | + ptr_to_data_attr = builder.alloca(data_attr.type) |
| 349 | + builder.store(data_attr, ptr_to_data_attr) |
| 350 | + expected_ref_ptr = builder.gep( |
| 351 | + builder.load(ptr_to_data_attr), [args[3]] |
| 352 | + ) |
| 353 | + |
| 354 | + expected_arg = builder.load(expected_ref_ptr) |
| 355 | + desired_arg = args[2] |
| 356 | + atomic_ref_ptr = builder.extract_value( |
| 357 | + args[0], |
| 358 | + context.data_model_manager.lookup(sig.args[0]).get_field_position( |
| 359 | + "ref" |
| 360 | + ), |
| 361 | + ) |
| 362 | + # add conditional bitcast for atomic_ref pointer, |
| 363 | + # expected[expected_idx], and desired |
| 364 | + if sig.args[0].dtype == types.float32: |
| 365 | + atomic_ref_ptr = builder.bitcast( |
| 366 | + atomic_ref_ptr, |
| 367 | + llvmir.PointerType( |
| 368 | + llvmir.IntType(32), addrspace=sig.args[0].address_space |
| 369 | + ), |
| 370 | + ) |
| 371 | + expected_arg = builder.bitcast(expected_arg, llvmir.IntType(32)) |
| 372 | + desired_arg = builder.bitcast(desired_arg, llvmir.IntType(32)) |
| 373 | + elif sig.args[0].dtype == types.float64: |
| 374 | + atomic_ref_ptr = builder.bitcast( |
| 375 | + atomic_ref_ptr, |
| 376 | + llvmir.PointerType( |
| 377 | + llvmir.IntType(64), addrspace=sig.args[0].address_space |
| 378 | + ), |
| 379 | + ) |
| 380 | + expected_arg = builder.bitcast(expected_arg, llvmir.IntType(64)) |
| 381 | + desired_arg = builder.bitcast(desired_arg, llvmir.IntType(64)) |
| 382 | + |
| 383 | + atomic_cmpexchg_fn_args = [ |
| 384 | + atomic_ref_ptr, |
| 385 | + context.get_constant( |
| 386 | + types.int32, get_scope(sig.args[0].memory_scope) |
| 387 | + ), |
| 388 | + context.get_constant( |
| 389 | + types.int32, |
| 390 | + get_memory_semantics_mask(sig.args[0].memory_order), |
| 391 | + ), |
| 392 | + context.get_constant( |
| 393 | + types.int32, |
| 394 | + get_memory_semantics_mask(sig.args[0].memory_order), |
| 395 | + ), |
| 396 | + desired_arg, |
| 397 | + expected_arg, |
| 398 | + ] |
| 399 | + |
| 400 | + ret_val = builder.call( |
| 401 | + get_or_insert_spv_atomic_compare_exchange_fn( |
| 402 | + context, builder.module, sig.args[0] |
| 403 | + ), |
| 404 | + atomic_cmpexchg_fn_args, |
| 405 | + ) |
| 406 | + |
| 407 | + # compare_exchange returns the old value stored in AtomicRef object. |
| 408 | + # If the return value is same as expected, then compare_exchange |
| 409 | + # succeeded in replacing AtomicRef object with desired. |
| 410 | + # If the return value is not same as expected, then store return |
| 411 | + # value in expected. |
| 412 | + # In either case, return result of cmp instruction. |
| 413 | + is_cmp_exchg_success = builder.icmp_signed("==", ret_val, expected_arg) |
| 414 | + |
| 415 | + with builder.if_else(is_cmp_exchg_success) as (then, otherwise): |
| 416 | + with then: |
| 417 | + pass |
| 418 | + with otherwise: |
| 419 | + if sig.args[0].dtype == types.float32: |
| 420 | + ret_val = builder.bitcast(ret_val, llvmir.FloatType()) |
| 421 | + elif sig.args[0].dtype == types.float64: |
| 422 | + ret_val = builder.bitcast(ret_val, llvmir.DoubleType()) |
| 423 | + builder.store(ret_val, expected_ref_ptr) |
| 424 | + return is_cmp_exchg_success |
| 425 | + |
| 426 | + return sig, _intrinsic_compare_exchange_gen |
| 427 | + |
| 428 | + |
326 | 429 | def _check_if_supported_ref(ref):
|
327 | 430 | supported = True
|
328 | 431 |
|
@@ -689,3 +792,50 @@ def ol_exchange_impl(atomic_ref, val):
|
689 | 792 | return _intrinsic_exchange(atomic_ref, val)
|
690 | 793 |
|
691 | 794 | return ol_exchange_impl
|
| 795 | + |
| 796 | + |
| 797 | +@overload_method( |
| 798 | + AtomicRefType, |
| 799 | + "compare_exchange", |
| 800 | + target=DPEX_KERNEL_EXP_TARGET_NAME, |
| 801 | +) |
| 802 | +def ol_compare_exchange( |
| 803 | + atomic_ref, |
| 804 | + expected_ref, |
| 805 | + desired, |
| 806 | + expected_idx=0, # pylint: disable=unused-argument |
| 807 | +): |
| 808 | + """SPIR-V overload for |
| 809 | + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange`. |
| 810 | +
|
| 811 | + Generates the same LLVM IR instruction as dpcpp for the |
| 812 | + `atomic_ref::compare_exchange_strong` function. |
| 813 | +
|
| 814 | + Raises: |
| 815 | + TypingError: When the dtype of the value passed to `compare_exchange` |
| 816 | + does not match the dtype of the AtomicRef type. |
| 817 | + """ |
| 818 | + |
| 819 | + _check_if_supported_ref(expected_ref) |
| 820 | + |
| 821 | + if atomic_ref.dtype != expected_ref.dtype: |
| 822 | + raise errors.TypingError( |
| 823 | + f"Type of value to compare_exchange: {expected_ref} does not match the " |
| 824 | + f"type of the reference: {atomic_ref.dtype} stored in the atomic ref." |
| 825 | + ) |
| 826 | + |
| 827 | + if atomic_ref.dtype != desired: |
| 828 | + raise errors.TypingError( |
| 829 | + f"Type of value to compare_exchange: {desired} does not match the " |
| 830 | + f"type of the reference: {atomic_ref.dtype} stored in the atomic ref." |
| 831 | + ) |
| 832 | + |
| 833 | + def ol_compare_exchange_impl( |
| 834 | + atomic_ref, expected_ref, desired, expected_idx=0 |
| 835 | + ): |
| 836 | + # pylint: disable=no-value-for-parameter |
| 837 | + return _intrinsic_compare_exchange( |
| 838 | + atomic_ref, expected_ref, desired, expected_idx |
| 839 | + ) |
| 840 | + |
| 841 | + return ol_compare_exchange_impl |
0 commit comments