Skip to content

Commit 9c25ebc

Browse files
authored
Merge pull request #11 from asder8215/using_cells
Using cells
2 parents 7b956d2 + a670d2c commit 9c25ebc

File tree

1 file changed

+46
-78
lines changed

1 file changed

+46
-78
lines changed

src/shardedringbuf.rs

Lines changed: 46 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
use crossbeam_utils::CachePadded;
88
use fastrand::usize as frand;
99
use std::{
10-
cell::UnsafeCell,
10+
cell::{Cell, UnsafeCell},
1111
fmt::{Debug, Write},
1212
mem::MaybeUninit,
1313
ptr,
@@ -60,9 +60,9 @@ struct InnerRingBuffer<T> {
6060
/// Box containing the content of the buffer
6161
items: Box<[UnsafeCell<MaybeUninit<T>>]>,
6262
/// Where to enqueue at in the Box
63-
enqueue_index: AtomicUsize,
63+
enqueue_index: Cell<usize>,
6464
/// Where to dequeue at in the Box
65-
dequeue_index: AtomicUsize,
65+
dequeue_index: Cell<usize>,
6666
}
6767

6868
/// Implements the InnerRingBuffer functions
@@ -78,17 +78,17 @@ impl<T> InnerRingBuffer<T> {
7878
}
7979
vec.into_boxed_slice()
8080
},
81-
enqueue_index: AtomicUsize::new(0),
82-
dequeue_index: AtomicUsize::new(0),
81+
enqueue_index: Cell::new(0),
82+
dequeue_index: Cell::new(0),
8383
}
8484
}
8585

8686
/// Helper function to see if a given index inside this buffer does
8787
/// indeed contain a valid item. Used in Drop Trait.
8888
#[inline(always)]
8989
fn is_item_in_shard(&self, item_ind: usize) -> bool {
90-
let enqueue_ind = self.enqueue_index.load(Ordering::Relaxed) % self.items.len();
91-
let dequeue_ind = self.dequeue_index.load(Ordering::Relaxed) % self.items.len();
90+
let enqueue_ind = self.enqueue_index.get() % self.items.len();
91+
let dequeue_ind = self.dequeue_index.get() % self.items.len();
9292

9393
if enqueue_ind > dequeue_ind {
9494
item_ind < enqueue_ind && item_ind >= dequeue_ind
@@ -371,10 +371,9 @@ impl<T> ShardedRingBuf<T> {
371371
#[inline(always)]
372372
fn enqueue_in_shard(&self, shard_ind: usize, item: T) {
373373
let inner = &self.inner_rb[shard_ind];
374-
// we use fetch add here because we want to obtain the previous value
375-
// to dequeue while also incrementing this counter (separate load and store
376-
// incurs more cost)
377-
let enqueue_index = inner.enqueue_index.fetch_add(1, Ordering::Relaxed) % inner.items.len();
374+
375+
let enqueue_index = inner.enqueue_index.get() % inner.items.len();
376+
inner.enqueue_index.set(inner.enqueue_index.get() + 1);
378377
let item_cell = inner.items[enqueue_index].get();
379378
// SAFETY: Only one thread will perform this operation and write to this
380379
// item cell
@@ -416,13 +415,13 @@ impl<T> ShardedRingBuf<T> {
416415
}
417416
loop {
418417
let inner = &self.inner_rb[0];
419-
let enq_counter = inner.enqueue_index.load(Ordering::Relaxed);
420-
let deq_counter = inner.dequeue_index.load(Ordering::Relaxed);
418+
let enq_counter = inner.enqueue_index.get();
419+
let deq_counter = inner.dequeue_index.get();
421420
let jobs = enq_counter.wrapping_sub(deq_counter);
422421

423422
if jobs != inner.items.len() {
424-
let enqueue_index =
425-
inner.enqueue_index.fetch_add(1, Ordering::Relaxed) % inner.items.len();
423+
let enqueue_index = inner.enqueue_index.get() % inner.items.len();
424+
inner.enqueue_index.set(inner.enqueue_index.get() + 1);
426425
let item_cell = inner.items[enqueue_index].get();
427426
unsafe {
428427
(*item_cell).write(item);
@@ -459,7 +458,8 @@ impl<T> ShardedRingBuf<T> {
459458
// we use fetch add here because we want to obtain the previous value
460459
// to dequeue while also incrementing this counter (separate load and store
461460
// incurs more cost)
462-
let dequeue_index = inner.dequeue_index.fetch_add(1, Ordering::Relaxed) % inner.items.len();
461+
let dequeue_index = inner.dequeue_index.get() % inner.items.len();
462+
inner.dequeue_index.set(inner.dequeue_index.get() + 1);
463463

464464
let item_cell = inner.items[dequeue_index].get();
465465

@@ -505,14 +505,14 @@ impl<T> ShardedRingBuf<T> {
505505
}
506506

507507
let inner = &self.inner_rb[0];
508-
let enq_counter = inner.enqueue_index.load(Ordering::Relaxed);
509-
let deq_counter = inner.dequeue_index.load(Ordering::Relaxed);
508+
let enq_counter = inner.enqueue_index.get();
509+
let deq_counter = inner.dequeue_index.get();
510510

511511
let jobs = enq_counter.wrapping_sub(deq_counter);
512512

513513
if jobs != 0 {
514-
let dequeue_index =
515-
inner.dequeue_index.fetch_add(1, Ordering::Relaxed) % inner.items.len();
514+
let dequeue_index = inner.dequeue_index.get() % inner.items.len();
515+
inner.dequeue_index.set(inner.dequeue_index.get() + 1);
516516
// SAFETY: Only one thread will claim this slot and perform this operation
517517
// And it's guaranteed that an item will exist here
518518
let item = unsafe { (*inner.items[dequeue_index].get()).assume_init_read() };
@@ -574,13 +574,13 @@ impl<T> ShardedRingBuf<T> {
574574
if self.poisoned.load(Ordering::Relaxed) && self.is_empty() {
575575
return None;
576576
}
577-
let enq_counter = inner.enqueue_index.load(Ordering::Relaxed);
578-
let deq_counter = inner.dequeue_index.load(Ordering::Relaxed);
577+
let enq_counter = inner.enqueue_index.get();
578+
let deq_counter = inner.dequeue_index.get();
579579
let jobs = enq_counter.wrapping_sub(deq_counter);
580580

581581
if jobs != 0 {
582-
let dequeue_index =
583-
inner.dequeue_index.fetch_add(1, Ordering::Relaxed) % inner.items.len();
582+
let dequeue_index = inner.dequeue_index.get() % inner.items.len();
583+
inner.dequeue_index.set(inner.dequeue_index.get() + 1);
584584
// SAFETY: Only one thread will claim this slot and perform this operation
585585
// And it's guaranteed that an item will exist here
586586
let item =
@@ -686,8 +686,8 @@ impl<T> ShardedRingBuf<T> {
686686
// reset each shard's inner ring buffer
687687
for shard in 0..self.shard_locks.len() {
688688
let inner = &self.inner_rb[shard];
689-
let mut drop_index = inner.dequeue_index.load(Ordering::Relaxed) % inner.items.len();
690-
let stop_index = inner.enqueue_index.load(Ordering::Relaxed) % inner.items.len();
689+
let mut drop_index = inner.dequeue_index.get() % inner.items.len();
690+
let stop_index = inner.enqueue_index.get() % inner.items.len();
691691
while drop_index != stop_index {
692692
// SAFETY: This will only clear out initialized values that have not
693693
// been dequeued. Note here that this method uses Relaxed loads.
@@ -696,12 +696,8 @@ impl<T> ShardedRingBuf<T> {
696696
}
697697
drop_index = (drop_index + 1) % self.inner_rb[shard].items.len();
698698
}
699-
self.inner_rb[shard]
700-
.enqueue_index
701-
.store(0, Ordering::Relaxed);
702-
self.inner_rb[shard]
703-
.dequeue_index
704-
.store(0, Ordering::Relaxed);
699+
self.inner_rb[shard].enqueue_index.set(0);
700+
self.inner_rb[shard].dequeue_index.set(0);
705701
}
706702
}
707703

@@ -727,8 +723,8 @@ impl<T> ShardedRingBuf<T> {
727723
// reset each shard's inner ring buffer
728724
for (shard_ind, _guard) in guards.into_iter().enumerate() {
729725
let inner = &self.inner_rb[shard_ind];
730-
let mut drop_index = inner.dequeue_index.load(Ordering::Acquire) % inner.items.len();
731-
let stop_index = inner.enqueue_index.load(Ordering::Acquire) % inner.items.len();
726+
let mut drop_index = inner.dequeue_index.get() % inner.items.len();
727+
let stop_index = inner.enqueue_index.get() % inner.items.len();
732728
while drop_index != stop_index {
733729
// SAFETY: This will only clear out initialized values that have not
734730
// been dequeued.
@@ -739,12 +735,8 @@ impl<T> ShardedRingBuf<T> {
739735
}
740736
drop_index = (drop_index + 1) % self.inner_rb[shard_ind].items.len();
741737
}
742-
self.inner_rb[shard_ind]
743-
.enqueue_index
744-
.store(0, Ordering::Release);
745-
self.inner_rb[shard_ind]
746-
.dequeue_index
747-
.store(0, Ordering::Release);
738+
self.inner_rb[shard_ind].enqueue_index.set(0);
739+
self.inner_rb[shard_ind].dequeue_index.set(0);
748740
}
749741
}
750742

@@ -798,10 +790,7 @@ impl<T> ShardedRingBuf<T> {
798790
pub fn is_shard_empty(&self, shard_ind: usize) -> bool {
799791
let inner = &self.inner_rb[shard_ind];
800792
// use these values as monotonic counter than indices
801-
let (enq_ind, deq_ind) = (
802-
inner.enqueue_index.load(Ordering::Relaxed),
803-
inner.dequeue_index.load(Ordering::Relaxed),
804-
);
793+
let (enq_ind, deq_ind) = (inner.enqueue_index.get(), inner.dequeue_index.get());
805794
let jobs = enq_ind.wrapping_sub(deq_ind);
806795
jobs == 0
807796
}
@@ -818,10 +807,7 @@ impl<T> ShardedRingBuf<T> {
818807

819808
let inner = &self.inner_rb[shard_ind];
820809
// use these values as monotonic counter than indices
821-
let (enq_ind, deq_ind) = (
822-
inner.enqueue_index.load(Ordering::Relaxed),
823-
inner.dequeue_index.load(Ordering::Relaxed),
824-
);
810+
let (enq_ind, deq_ind) = (inner.enqueue_index.get(), inner.dequeue_index.get());
825811
let jobs = enq_ind.wrapping_sub(deq_ind);
826812
jobs == 0
827813
}
@@ -879,10 +865,7 @@ impl<T> ShardedRingBuf<T> {
879865
let inner = &self.inner_rb[shard_ind];
880866
let item_len = inner.items.len();
881867
// use these values as monotonic counter than indices
882-
let (enq_ind, deq_ind) = (
883-
inner.enqueue_index.load(Ordering::Relaxed),
884-
inner.dequeue_index.load(Ordering::Relaxed),
885-
);
868+
let (enq_ind, deq_ind) = (inner.enqueue_index.get(), inner.dequeue_index.get());
886869
let jobs = enq_ind.wrapping_sub(deq_ind);
887870
jobs == item_len
888871
}
@@ -900,10 +883,7 @@ impl<T> ShardedRingBuf<T> {
900883
let inner = &self.inner_rb[shard_ind];
901884
let item_len = inner.items.len();
902885
// use these values as monotonic counter than indices
903-
let (enq_ind, deq_ind) = (
904-
inner.enqueue_index.load(Ordering::Relaxed),
905-
inner.dequeue_index.load(Ordering::Relaxed),
906-
);
886+
let (enq_ind, deq_ind) = (inner.enqueue_index.get(), inner.dequeue_index.get());
907887
let jobs = enq_ind.wrapping_sub(deq_ind);
908888
jobs == item_len
909889
}
@@ -925,7 +905,7 @@ impl<T> ShardedRingBuf<T> {
925905

926906
// grab enq val
927907
let inner = &self.inner_rb[shard_ind];
928-
let enq_ind = inner.enqueue_index.load(Ordering::Relaxed) % inner.items.len();
908+
let enq_ind = inner.enqueue_index.get() % inner.items.len();
929909

930910
Some(enq_ind)
931911
}
@@ -947,7 +927,7 @@ impl<T> ShardedRingBuf<T> {
947927

948928
// grab enq val
949929
let inner = &self.inner_rb[shard_ind];
950-
let enq_ind = inner.enqueue_index.load(Ordering::Relaxed) % inner.items.len();
930+
let enq_ind = inner.enqueue_index.get() % inner.items.len();
951931

952932
Some(enq_ind)
953933
}
@@ -969,7 +949,7 @@ impl<T> ShardedRingBuf<T> {
969949

970950
// grab deq ind val
971951
let inner = &self.inner_rb[shard_ind];
972-
let deq_ind = inner.dequeue_index.load(Ordering::Relaxed) % inner.items.len();
952+
let deq_ind = inner.dequeue_index.get() % inner.items.len();
973953

974954
Some(deq_ind)
975955
}
@@ -991,7 +971,7 @@ impl<T> ShardedRingBuf<T> {
991971

992972
// grab deq ind val
993973
let inner = &self.inner_rb[shard_ind];
994-
let deq_ind = inner.dequeue_index.load(Ordering::Relaxed) % inner.items.len();
974+
let deq_ind = inner.dequeue_index.get() % inner.items.len();
995975

996976
Some(deq_ind)
997977
}
@@ -1009,10 +989,7 @@ impl<T> ShardedRingBuf<T> {
1009989
}
1010990

1011991
let inner = &self.inner_rb[shard_ind];
1012-
let (enq_count, deq_count) = (
1013-
inner.enqueue_index.load(Ordering::Relaxed),
1014-
inner.dequeue_index.load(Ordering::Relaxed),
1015-
);
992+
let (enq_count, deq_count) = (inner.enqueue_index.get(), inner.dequeue_index.get());
1016993
let jobs = enq_count.wrapping_sub(deq_count);
1017994
Some(jobs)
1018995
}
@@ -1033,10 +1010,7 @@ impl<T> ShardedRingBuf<T> {
10331010
ShardLockGuard::acquire(&self.shard_locks[shard_ind]).await;
10341011

10351012
let inner = &self.inner_rb[shard_ind];
1036-
let (enq_count, deq_count) = (
1037-
inner.enqueue_index.load(Ordering::Relaxed),
1038-
inner.dequeue_index.load(Ordering::Relaxed),
1039-
);
1013+
let (enq_count, deq_count) = (inner.enqueue_index.get(), inner.dequeue_index.get());
10401014
let jobs = enq_count.wrapping_sub(deq_count);
10411015
Some(jobs)
10421016
}
@@ -1051,10 +1025,7 @@ impl<T> ShardedRingBuf<T> {
10511025
let mut count = Vec::new();
10521026

10531027
for shard in &self.inner_rb {
1054-
let (enq_count, deq_count) = (
1055-
shard.enqueue_index.load(Ordering::Relaxed),
1056-
shard.dequeue_index.load(Ordering::Relaxed),
1057-
);
1028+
let (enq_count, deq_count) = (shard.enqueue_index.get(), shard.dequeue_index.get());
10581029
let jobs = enq_count.wrapping_sub(deq_count);
10591030
count.push(jobs);
10601031
}
@@ -1081,10 +1052,7 @@ impl<T> ShardedRingBuf<T> {
10811052
// guard for me when it goes to the next iteration
10821053
for (shard_ind, _guard) in guards.into_iter().enumerate() {
10831054
let shard = &self.inner_rb[shard_ind];
1084-
let (enq_count, deq_count) = (
1085-
shard.enqueue_index.load(Ordering::Relaxed),
1086-
shard.dequeue_index.load(Ordering::Relaxed),
1087-
);
1055+
let (enq_count, deq_count) = (shard.enqueue_index.get(), shard.dequeue_index.get());
10881056
let jobs = enq_count.wrapping_sub(deq_count);
10891057
count.push(jobs);
10901058
}
@@ -1096,8 +1064,8 @@ impl<T> ShardedRingBuf<T> {
10961064
#[inline(always)]
10971065
fn is_item_in_shard(&self, item_ind: usize, shard_ind: usize) -> bool {
10981066
let inner = &self.inner_rb[shard_ind];
1099-
let enqueue_ind = inner.enqueue_index.load(Ordering::Relaxed) % inner.items.len();
1100-
let dequeue_ind = inner.dequeue_index.load(Ordering::Relaxed) % inner.items.len();
1067+
let enqueue_ind = inner.enqueue_index.get() % inner.items.len();
1068+
let dequeue_ind = inner.dequeue_index.get() % inner.items.len();
11011069

11021070
if enqueue_ind > dequeue_ind {
11031071
item_ind < enqueue_ind && item_ind >= dequeue_ind

0 commit comments

Comments
 (0)