Skip to content

Commit 8d264ad

Browse files
committed
Feat: atomics 3; return of the atomics
1 parent 559af40 commit 8d264ad

File tree

2 files changed

+114
-16
lines changed

2 files changed

+114
-16
lines changed

crates/cuda_std/src/atomic.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ macro_rules! atomic_float {
8989
}
9090
}
9191

92+
/// Consumes the atomic and returns the contained value.
93+
pub fn into_inner(self) -> $float_ty {
94+
self.v.into_inner()
95+
}
96+
9297
#[cfg(not(target_os = "cuda"))]
9398
fn as_atomic_bits(&self) -> &core::sync::atomic::[<AtomicU $width>] {
9499
// SAFETY: AtomicU32/U64 pointers are compatible with UnsafeCell<u32/u64>.
@@ -120,6 +125,60 @@ macro_rules! atomic_float {
120125
self.update_with(order, |v| v + val)
121126
}
122127

128+
/// Subtracts from the current value, returning the previous value **before** the subtraction.
129+
///
130+
/// Note: this is actually implemented as `old + (-val)`, because CUDA does not have a specialized sub instruction.
131+
///
132+
$(#[doc = safety_doc!($unsafety)])?
133+
pub $($unsafety)? fn fetch_sub(&self, val: $float_ty, order: Ordering) -> $float_ty {
134+
#[cfg(target_os = "cuda")]
135+
// SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
136+
unsafe {
137+
mid::[<atomic_fetch_sub_ $float_ty _ $scope>](self.v.get(), order, val)
138+
}
139+
#[cfg(not(target_os = "cuda"))]
140+
self.update_with(order, |v| v - val)
141+
}
142+
143+
/// Bitwise "and" with the current value. Returns the value **before** the "and".
144+
///
145+
$(#[doc = safety_doc!($unsafety)])?
146+
pub $($unsafety)? fn fetch_and(&self, val: $float_ty, order: Ordering) -> $float_ty {
147+
#[cfg(target_os = "cuda")]
148+
// SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
149+
unsafe {
150+
mid::[<atomic_fetch_and_ $float_ty _ $scope>](self.v.get(), order, val)
151+
}
152+
#[cfg(not(target_os = "cuda"))]
153+
self.update_with(order, |v| $float_ty::from_bits(v.to_bits() & val.to_bits()))
154+
}
155+
156+
/// Bitwise "or" with the current value. Returns the value **before** the "or".
157+
///
158+
$(#[doc = safety_doc!($unsafety)])?
159+
pub $($unsafety)? fn fetch_or(&self, val: $float_ty, order: Ordering) -> $float_ty {
160+
#[cfg(target_os = "cuda")]
161+
// SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
162+
unsafe {
163+
mid::[<atomic_fetch_or_ $float_ty _ $scope>](self.v.get(), order, val)
164+
}
165+
#[cfg(not(target_os = "cuda"))]
166+
self.update_with(order, |v| $float_ty::from_bits(v.to_bits() | val.to_bits()))
167+
}
168+
169+
/// Bitwise "xor" with the current value. Returns the value **before** the "xor".
170+
///
171+
$(#[doc = safety_doc!($unsafety)])?
172+
pub $($unsafety)? fn fetch_xor(&self, val: $float_ty, order: Ordering) -> $float_ty {
173+
#[cfg(target_os = "cuda")]
174+
// SAFETY: data races are prevented by atomic intrinsics and the pointer we get is valid.
175+
unsafe {
176+
mid::[<atomic_fetch_xor_ $float_ty _ $scope>](self.v.get(), order, val)
177+
}
178+
#[cfg(not(target_os = "cuda"))]
179+
self.update_with(order, |v| $float_ty::from_bits(v.to_bits() ^ val.to_bits()))
180+
}
181+
123182
/// Atomically loads the value behind this atomic.
124183
///
125184
/// `load` takes an [`Ordering`] argument which describes the memory ordering of this operation.

crates/cuda_std/src/atomic/intrinsics.rs

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,56 @@ use paste::paste;
66

77
#[gpu_only]
88
pub unsafe fn membar_device() {
9-
asm!("membar.gl");
9+
asm!("membar.gl;");
1010
}
1111

1212
#[gpu_only]
1313
pub unsafe fn membar_block() {
14-
asm!("membar.cta");
14+
asm!("membar.cta;");
1515
}
1616

1717
#[gpu_only]
1818
pub unsafe fn membar_system() {
19-
asm!("membar.sys");
19+
asm!("membar.sys;");
2020
}
2121

2222
#[gpu_only]
2323
pub unsafe fn fence_sc_device() {
24-
asm!("fence.sc.gl");
24+
asm!("fence.sc.gl;");
2525
}
2626

2727
#[gpu_only]
2828
pub unsafe fn fence_sc_block() {
29-
asm!("fence.sc.cta");
29+
asm!("fence.sc.cta;");
3030
}
3131

3232
#[gpu_only]
3333
pub unsafe fn fence_sc_system() {
34-
asm!("fence.sc.sys");
34+
asm!("fence.sc.sys;");
3535
}
3636

3737
#[gpu_only]
3838
pub unsafe fn fence_acqrel_device() {
39-
asm!("fence.acq_rel.gl");
39+
asm!("fence.acq_rel.gl;");
4040
}
4141

4242
#[gpu_only]
4343
pub unsafe fn fence_acqrel_block() {
44-
asm!("fence.acq_rel.sys");
44+
asm!("fence.acq_rel.sys;");
4545
}
4646

4747
#[gpu_only]
4848
pub unsafe fn fence_acqrel_system() {
49-
asm!("fence.acq_rel.sys");
49+
asm!("fence.acq_rel.sys;");
50+
}
51+
52+
macro_rules! load_scope {
53+
(volatile, $scope:ident) => {
54+
""
55+
};
56+
($ordering:ident, $scope:ident) => {
57+
concat!(".", stringify!($scope))
58+
};
5059
}
5160

5261
macro_rules! load {
@@ -59,7 +68,7 @@ macro_rules! load {
5968
pub unsafe fn [<atomic_load_ $ordering _ $width _ $scope>](ptr: *const [<u $width>]) -> [<u $width>] {
6069
let mut out;
6170
asm!(
62-
concat!("ld.", stringify!($ordering), ".", stringify!($scope_asm), ".", stringify!([<u $width>]), "{}, [{}]"),
71+
concat!("ld.", stringify!($ordering), load_scope!($ordering, $scope), ".", stringify!([<u $width>]), " {}, [{}];"),
6372
out([<reg $width>]) out,
6473
in(reg64) ptr
6574
);
@@ -105,7 +114,7 @@ macro_rules! store {
105114
#[doc = concat!("Performs a ", stringify!($ordering), " atomic store at the ", stringify!($scope), " level with a width of ", stringify!($width), " bits")]
106115
pub unsafe fn [<atomic_store_ $ordering _ $width _ $scope>](ptr: *mut [<u $width>], val: [<u $width>]) {
107116
asm!(
108-
concat!("st.", stringify!($ordering), ".", stringify!($scope_asm), ".", stringify!([<u $width>]), "[{}], {}"),
117+
concat!("st.", stringify!($ordering), load_scope!($ordering, $scope), ".", stringify!([<u $width>]), " [{}], {};"),
109118
in(reg64) ptr,
110119
in([<reg $width>]) val,
111120
);
@@ -141,6 +150,19 @@ store! {
141150
volatile, 64, system, sys,
142151
}
143152

153+
#[allow(unused_macros)]
154+
macro_rules! ptx_type {
155+
(i32) => {
156+
"s32"
157+
};
158+
(i64) => {
159+
"s64"
160+
};
161+
($ty:ident) => {
162+
stringify!($ty)
163+
};
164+
}
165+
144166
#[allow(unused_macros)]
145167
macro_rules! ordering {
146168
(volatile) => {
@@ -172,7 +194,8 @@ macro_rules! atomic_fetch_op_2_reg {
172194
".",
173195
stringify!($op),
174196
".",
175-
"{}, [{}]"
197+
ptx_type!($type),
198+
" {}, [{}];"
176199
),
177200
out([<reg $width>]) out,
178201
in(reg64) ptr,
@@ -359,7 +382,8 @@ macro_rules! atomic_fetch_op_3_reg {
359382
".",
360383
stringify!($op),
361384
".",
362-
"{}, [{}], {}"
385+
ptx_type!($type),
386+
" {}, [{}], {};"
363387
),
364388
out([<reg $width>]) out,
365389
in(reg64) ptr,
@@ -1101,7 +1125,8 @@ macro_rules! atomic_fetch_op_4_reg {
11011125
".",
11021126
stringify!($op),
11031127
".",
1104-
"{}, [{}], {}, {}"
1128+
ptx_type!($type),
1129+
" {}, [{}], {}, {};"
11051130
),
11061131
out([<reg $width>]) out,
11071132
in(reg64) ptr,
@@ -1227,6 +1252,19 @@ atomic_fetch_op_4_reg! {
12271252
volatile, cas, 64, f64, system, sys,
12281253
}
12291254

1255+
#[allow(unused_macros)]
1256+
macro_rules! negation {
1257+
(u32, $val:ident) => {{
1258+
-($val as i32)
1259+
}};
1260+
(u64, $val:ident) => {{
1261+
-($val as i64)
1262+
}};
1263+
($type:ty, $val:ident) => {{
1264+
-$val
1265+
}};
1266+
}
1267+
12301268
// atomic sub is a little special: nvcc implements it as an atomic add with a negated operand. PTX
12311269
// does not have atom.sub.
12321270
macro_rules! atomic_sub {
@@ -1246,11 +1284,12 @@ macro_rules! atomic_sub {
12461284
".",
12471285
"add",
12481286
".",
1249-
"{}, [{}], {}"
1287+
ptx_type!($type),
1288+
" {}, [{}], {};"
12501289
),
12511290
out([<reg $width>]) out,
12521291
in(reg64) ptr,
1253-
in([<reg $width>]) -(val as [<i $width>]),
1292+
in([<reg $width>]) negation!($type, val),
12541293
);
12551294
out
12561295
}

0 commit comments

Comments
 (0)