Skip to content

Commit 559af40

Browse files
committed
Feat: atomics 2; unsafe boogaloo
1 parent cba6d1c commit 559af40

File tree

7 files changed

+726
-10
lines changed

7 files changed

+726
-10
lines changed

crates/cuda_builder/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
449449
}
450450
}
451451

452+
let arch = format!("{:?}0", builder.arch);
453+
cargo.env("CUDA_ARCH", arch.strip_prefix("Compute").unwrap());
454+
452455
let cargo_encoded_rustflags = join_checking_for_separators(rustflags, "\x1f");
453456

454457
let build = cargo

crates/cuda_std/src/atomic.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,157 @@
1818
//! Therefore we chose to go with the approach of implementing all atomics inside cuda_std. In the future, we may support
1919
//! a subset of core atomics, but for now, you will have to use cuda_std atomics.
2020
21+
#![allow(unused_unsafe)]
22+
2123
pub mod intrinsics;
24+
pub mod mid;
25+
26+
use core::cell::UnsafeCell;
27+
use core::sync::atomic::Ordering;
28+
use paste::paste;
29+
30+
#[allow(dead_code)]
31+
fn fail_order(order: Ordering) -> Ordering {
32+
match order {
33+
Ordering::Release | Ordering::Relaxed => Ordering::Relaxed,
34+
Ordering::Acquire | Ordering::AcqRel => Ordering::Acquire,
35+
Ordering::SeqCst => Ordering::SeqCst,
36+
x => x,
37+
}
38+
}
39+
40+
macro_rules! scope_doc {
41+
(device) => {
42+
"a single device (GPU)."
43+
};
44+
(block) => {
45+
"a single thread block (also called a CTA, cooperative thread array)."
46+
};
47+
(system) => {
48+
"the entire system."
49+
};
50+
}
51+
52+
macro_rules! safety_doc {
53+
($($unsafety:ident)?) => {
54+
$(
55+
concat!(
56+
"# Safety\n",
57+
concat!("This function is ", stringify!($unsafety), " because it does not synchronize\n"),
58+
"across the entire GPU or System, which leaves it open for data races if used incorrectly"
59+
)
60+
)?
61+
};
62+
}
63+
64+
macro_rules! atomic_float {
65+
($float_ty:ident, $atomic_ty:ident, $align:tt, $scope:ident, $width:tt $(,$unsafety:ident)?) => {
66+
#[doc = concat!("A ", stringify!($width), "-bit float type which can be safely shared between threads and synchronizes across ", scope_doc!($scope))]
67+
///
68+
/// This type is guaranteed to have the same memory representation as the underlying integer
69+
/// type [`
70+
#[doc = stringify!($float_ty)]
71+
/// `].
72+
///
73+
/// The functions on this type map to hardware instructions on CUDA platforms, and are emulated
74+
/// with a CAS loop on the CPU (non-CUDA targets).
75+
#[repr(C, align($align))]
76+
pub struct $atomic_ty {
77+
v: UnsafeCell<$float_ty>,
78+
}
79+
80+
// SAFETY: atomic ops make sure this is fine
81+
unsafe impl Sync for $atomic_ty {}
82+
83+
impl $atomic_ty {
84+
paste! {
85+
/// Creates a new atomic float.
86+
pub const fn new(v: $float_ty) -> $atomic_ty {
87+
Self {
88+
v: UnsafeCell::new(v),
89+
}
90+
}
91+
92+
#[cfg(not(target_os = "cuda"))]
93+
fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
94+
// SAFETY: AtomicU32/U64 pointers are compatible with UnsafeCell<u32/u64>.
95+
unsafe {
96+
core::mem::transmute(self)
97+
}
98+
}
99+
100+
#[cfg(not(target_os = "cuda"))]
101+
fn update_with(&self, order: Ordering, mut func: impl FnMut($float_ty) -> $float_ty) -> $float_ty {
102+
let res = self
103+
.as_atomic_bits()
104+
.fetch_update(order, fail_order(order), |prev| {
105+
Some(func($float_ty::from_bits(prev))).map($float_ty::to_bits)
106+
}).unwrap();
107+
$float_ty::from_bits(res)
108+
}
109+
110+
/// Adds to the current value, returning the previous value **before** the addition.
111+
///
112+
$(#[doc = safety_doc!($unsafety)])?
113+
pub $($unsafety)? fn fetch_add(&self, val: $float_ty, order: Ordering) -> $float_ty {
114+
#[cfg(target_os = "cuda")]
115+
// SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
116+
unsafe {
117+
mid::[<atomic_fetch_add_ $float_ty _ $scope>](self.v.get(), order, val)
118+
}
119+
#[cfg(not(target_os = "cuda"))]
120+
self.update_with(order, |v| v + val)
121+
}
122+
123+
/// Atomically loads the value behind this atomic.
124+
///
125+
/// `load` takes an [`Ordering`] argument which describes the memory ordering of this operation.
126+
/// Possible values are [`Ordering::SeqCst`], [`Ordering::Acquire`], and [`Ordering::Relaxed`].
127+
///
128+
/// # Panics
129+
///
130+
/// Panics if `order` is [`Ordering::Release`] or [`Ordering::AcqRel`].
131+
///
132+
$(#[doc = safety_doc!($unsafety)])?
133+
pub $($unsafety)? fn load(&self, order: Ordering) -> $float_ty {
134+
#[cfg(target_os = "cuda")]
135+
unsafe {
136+
let val = mid::[<atomic_load_ $width _ $scope>](self.v.get().cast(), order);
137+
$float_ty::from_bits(val)
138+
}
139+
#[cfg(not(target_os = "cuda"))]
140+
{
141+
let val = self.as_atomic_bits().load(order);
142+
$float_ty::from_bits(val)
143+
}
144+
}
145+
146+
/// Atomically stores a value into this atomic.
147+
///
148+
/// `store` takes an [`Ordering`] argument which describes the memory ordering of this operation.
149+
/// Possible values are [`Ordering::SeqCst`], [`Ordering::Release`], and [`Ordering::Relaxed`].
150+
///
151+
/// # Panics
152+
///
153+
/// Panics if `order` is [`Ordering::Acquire`] or [`Ordering::AcqRel`].
154+
///
155+
$(#[doc = safety_doc!($unsafety)])?
156+
pub $($unsafety)? fn store(&self, val: $float_ty, order: Ordering) {
157+
#[cfg(target_os = "cuda")]
158+
unsafe {
159+
mid::[<atomic_store_ $width _ $scope>](self.v.get().cast(), order, val.to_bits());
160+
}
161+
#[cfg(not(target_os = "cuda"))]
162+
self.as_atomic_bits().store(val.to_bits(), order);
163+
}
164+
}
165+
}
166+
};
167+
}
168+
169+
atomic_float!(f32, AtomicF32, 4, device, 32);
170+
atomic_float!(f64, AtomicF64, 8, device, 64);
171+
atomic_float!(f32, BlockAtomicF32, 4, block, 32, unsafe);
172+
atomic_float!(f64, BlockAtomicF64, 8, block, 64, unsafe);
173+
atomic_float!(f32, SystemAtomicF32, 4, device, 32);
174+
atomic_float!(f64, SystemAtomicF64, 8, device, 64);

0 commit comments

Comments
 (0)