|
18 | 18 | //! Therefore we chose to go with the approach of implementing all atomics inside cuda_std. In the future, we may support
|
19 | 19 | //! a subset of core atomics, but for now, you will have to use cuda_std atomics.
|
20 | 20 |
|
| 21 | +#![allow(unused_unsafe)] |
| 22 | + |
21 | 23 | pub mod intrinsics;
|
| 24 | +pub mod mid; |
| 25 | + |
| 26 | +use core::cell::UnsafeCell; |
| 27 | +use core::sync::atomic::Ordering; |
| 28 | +use paste::paste; |
| 29 | + |
/// Derives the failure ordering for the CPU `fetch_update` emulation from the
/// caller-supplied success ordering: a compare-exchange failure may not carry
/// `Release` semantics, so `Release`/`AcqRel` are weakened accordingly.
#[allow(dead_code)] // only referenced on the non-CUDA (CPU emulation) paths
fn fail_order(order: Ordering) -> Ordering {
    match order {
        Ordering::Acquire | Ordering::AcqRel => Ordering::Acquire,
        Ordering::Release | Ordering::Relaxed => Ordering::Relaxed,
        other => other,
    }
}
| 39 | + |
/// Expands to the trailing sentence of a generated type's doc comment,
/// describing the synchronization scope the atomic operates at.
macro_rules! scope_doc {
    (system) => {
        "the entire system."
    };
    (device) => {
        "a single device (GPU)."
    };
    (block) => {
        "a single thread block (also called a CTA, cooperative thread array)."
    };
}
| 51 | + |
/// Expands to a `# Safety` doc section when an unsafety marker ident is
/// supplied (e.g. `unsafe`); expands to nothing when the argument is absent.
macro_rules! safety_doc {
    ($($unsafety:ident)?) => {
        $(
            concat!(
                "# Safety\n",
                "This function is ",
                stringify!($unsafety),
                " because it does not synchronize\n",
                "across the entire GPU or System, which leaves it open for data races if used incorrectly"
            )
        )?
    };
}
| 63 | + |
// Generates an atomic float wrapper type over `$float_ty`:
// - `$atomic_ty`: name of the generated struct
// - `$align`:     required alignment in bytes (matches the same-width integer atomic)
// - `$scope`:     synchronization scope ident (`device`, `block`, or `system`); pasted
//                 into the `mid::atomic_*_<scope>` intrinsic names on the CUDA path
// - `$width`:     bit width, pasted into `AtomicU<width>` for the CPU fallback path
// - optional trailing `unsafe` token: marks every operation `unsafe` and attaches the
//   `safety_doc!` section (used for scopes narrower than the whole GPU/system)
macro_rules! atomic_float {
    ($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
        #[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
        ///
        /// This type is guaranteed to have the same memory representation as the underlying
        /// floating-point type [`
        #[doc = stringify!($float_ty)]
        /// `].
        ///
        /// The functions on this type map to hardware instructions on CUDA platforms, and are emulated
        /// with a CAS loop on the CPU (non-CUDA targets).
        #[repr(C, align($align))]
        pub struct $atomic_ty {
            v: UnsafeCell<$float_ty>,
        }

        // SAFETY: atomic ops make sure this is fine
        unsafe impl Sync for $atomic_ty {}

        impl $atomic_ty {
            paste! {
                /// Creates a new atomic float.
                pub const fn new(v: $float_ty) -> $atomic_ty {
                    Self {
                        v: UnsafeCell::new(v),
                    }
                }

                // CPU-only: reinterpret this atomic as the integer atomic of the same
                // width so std's atomics can emulate the float operations bitwise.
                #[cfg(not(target_os = "cuda"))]
                fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
                    // SAFETY: AtomicU32/U64 pointers are compatible with UnsafeCell<u32/u64>.
                    unsafe {
                        core::mem::transmute(self)
                    }
                }

                // CPU-only: compare-exchange-loop emulation of a float read-modify-write
                // on the bit pattern; returns the value *before* the update. `fetch_update`
                // only errors when the closure returns `None`, which it never does here,
                // so the `unwrap` cannot fail.
                #[cfg(not(target_os = "cuda"))]
                fn update_with(&self, order: Ordering, mut func: impl FnMut($float_ty) -> $float_ty) -> $float_ty {
                    let res = self
                        .as_atomic_bits()
                        .fetch_update(order, fail_order(order), |prev| {
                            Some(func($float_ty::from_bits(prev))).map($float_ty::to_bits)
                        }).unwrap();
                    $float_ty::from_bits(res)
                }

                /// Adds to the current value, returning the previous value **before** the addition.
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn fetch_add(&self, val: $float_ty, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
                    unsafe {
                        // NOTE(review): assumes `mid`'s scoped fetch-add maps `order` to
                        // std-compatible semantics — confirm in the `mid` module.
                        mid::[<atomic_fetch_add_ $float_ty _ $scope>](self.v.get(), order, val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.update_with(order, |v| v + val)
                }

                /// Atomically loads the value behind this atomic.
                ///
                /// `load` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Acquire`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Release`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn load(&self, order: Ordering) -> $float_ty {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: the pointer is valid, and the scoped load intrinsic prevents data races.
                    unsafe {
                        let val = mid::[<atomic_load_ $width _ $scope>](self.v.get().cast(), order);
                        $float_ty::from_bits(val)
                    }
                    #[cfg(not(target_os = "cuda"))]
                    {
                        let val = self.as_atomic_bits().load(order);
                        $float_ty::from_bits(val)
                    }
                }

                /// Atomically stores a value into this atomic.
                ///
                /// `store` takes an [`Ordering`] argument which describes the memory ordering of this operation.
                /// Possible values are [`Ordering::SeqCst`], [`Ordering::Release`], and [`Ordering::Relaxed`].
                ///
                /// # Panics
                ///
                /// Panics if `order` is [`Ordering::Acquire`] or [`Ordering::AcqRel`].
                ///
                $(#[doc = safety_doc!($unsafety)])?
                pub $($unsafety)? fn store(&self, val: $float_ty, order: Ordering) {
                    #[cfg(target_os = "cuda")]
                    // SAFETY: the pointer is valid, and the scoped store intrinsic prevents data races.
                    unsafe {
                        mid::[<atomic_store_ $width _ $scope>](self.v.get().cast(), order, val.to_bits());
                    }
                    #[cfg(not(target_os = "cuda"))]
                    self.as_atomic_bits().store(val.to_bits(), order);
                }
            }
        }
    };
}
| 168 | + |
| 169 | +atomic_float!(f32, AtomicF32, 4, device, 32); |
| 170 | +atomic_float!(f64, AtomicF64, 8, device, 64); |
| 171 | +atomic_float!(f32, BlockAtomicF32, 4, block, 32, unsafe); |
| 172 | +atomic_float!(f64, BlockAtomicF64, 8, block, 64, unsafe); |
| 173 | +atomic_float!(f32, SystemAtomicF32, 4, device, 32); |
| 174 | +atomic_float!(f64, SystemAtomicF64, 8, device, 64); |
0 commit comments