Skip to content

Commit 9ba7c16

Browse files
committed
Make TestStore async writes actually async, with manual complete
`TestStore` recently got the ability to make async writes, but wasn't a very useful test as all writes actually completed immediately. Instead, here, we make the writes actually-async, forcing the test to mark writes complete as required.
1 parent 6199bcb commit 9ba7c16

File tree

1 file changed

+124
-21
lines changed

1 file changed

+124
-21
lines changed

lightning/src/util/test_utils.rs

Lines changed: 124 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ use core::future::Future;
8989
use core::mem;
9090
use core::pin::Pin;
9191
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
92+
use core::task::{Context, Poll, Waker};
9293
use core::time::Duration;
9394

9495
use bitcoin::psbt::Psbt;
@@ -856,15 +857,93 @@ impl<Signer: sign::ecdsa::EcdsaChannelSigner> Persist<Signer> for TestPersister
856857
}
857858
}
858859

860+
// A simple multi-producer-single-consumer one-shot channel
861+
type OneShotChannelState = Arc<Mutex<(Option<Result<(), io::Error>>, Option<Waker>)>>;
862+
struct OneShotChannel(OneShotChannelState);
863+
impl Future for OneShotChannel {
864+
type Output = Result<(), io::Error>;
865+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
866+
let mut state = self.0.lock().unwrap();
867+
// If the future is complete, take() the result and return it,
868+
state.0.take().map(|res| Poll::Ready(res)).unwrap_or_else(|| {
869+
// otherwise, store the waker so that the future will be poll()ed again when the result
870+
// is ready.
871+
state.1 = Some(cx.waker().clone());
872+
Poll::Pending
873+
})
874+
}
875+
}
876+
877+
/// An in-memory KVStore for testing.
878+
///
879+
/// Sync writes always complete immediately while async writes always block until manually
880+
/// completed with [`Self::complete_async_writes_through`] or [`Self::complete_all_async_writes`].
881+
///
882+
/// Removes always complete immediately.
859883
pub struct TestStore {
884+
pending_async_writes: Mutex<HashMap<String, Vec<(usize, OneShotChannelState, Vec<u8>)>>>,
860885
persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
861886
read_only: bool,
862887
}
863888

864889
impl TestStore {
865890
pub fn new(read_only: bool) -> Self {
891+
let pending_async_writes = Mutex::new(new_hash_map());
866892
let persisted_bytes = Mutex::new(new_hash_map());
867-
Self { persisted_bytes, read_only }
893+
Self { pending_async_writes, persisted_bytes, read_only }
894+
}
895+
896+
pub fn list_pending_async_writes(
897+
&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
898+
) -> Vec<usize> {
899+
let key = format!("{primary_namespace}/{secondary_namespace}/{key}");
900+
let writes_lock = self.pending_async_writes.lock().unwrap();
901+
writes_lock
902+
.get(&key)
903+
.map(|v| v.iter().map(|(id, _, _)| *id).collect())
904+
.unwrap_or(Vec::new())
905+
}
906+
907+
/// Completes all pending async writes for the given namespace and key, up to and through the
908+
/// given `write_id` (which can be fetched from [`Self::list_pending_async_writes`]).
909+
pub fn complete_async_writes_through(
910+
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, write_id: usize,
911+
) {
912+
let prefix = format!("{primary_namespace}/{secondary_namespace}");
913+
let key = format!("{primary_namespace}/{secondary_namespace}/{key}");
914+
915+
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
916+
let mut writes_lock = self.pending_async_writes.lock().unwrap();
917+
918+
let pending_writes = writes_lock.get_mut(&key).expect("No pending writes for given key");
919+
pending_writes.retain(|(id, res, data)| {
920+
if *id <= write_id {
921+
let namespace = persisted_lock.entry(prefix.clone()).or_insert(new_hash_map());
922+
*namespace.entry(key.to_string()).or_default() = data.clone();
923+
let mut future_state = res.lock().unwrap();
924+
future_state.0 = Some(Ok(()));
925+
if let Some(waker) = future_state.1.take() {
926+
waker.wake();
927+
}
928+
false
929+
} else {
930+
true
931+
}
932+
});
933+
}
934+
935+
/// Completes all pending async writes on all namespaces and keys.
936+
pub fn complete_all_async_writes(&self) {
937+
let pending_writes: Vec<String> =
938+
self.pending_async_writes.lock().unwrap().keys().cloned().collect();
939+
for key in pending_writes {
940+
let mut levels = key.split("/");
941+
let primary = levels.next().unwrap();
942+
let secondary = levels.next().unwrap();
943+
let key = levels.next().unwrap();
944+
assert!(levels.next().is_none());
945+
self.complete_async_writes_through(primary, secondary, key, usize::MAX);
946+
}
868947
}
869948

870949
fn read_internal(
@@ -885,23 +964,6 @@ impl TestStore {
885964
}
886965
}
887966

888-
fn write_internal(
889-
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
890-
) -> io::Result<()> {
891-
if self.read_only {
892-
return Err(io::Error::new(
893-
io::ErrorKind::PermissionDenied,
894-
"Cannot modify read-only store",
895-
));
896-
}
897-
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
898-
899-
let prefixed = format!("{primary_namespace}/{secondary_namespace}");
900-
let outer_e = persisted_lock.entry(prefixed).or_insert(new_hash_map());
901-
outer_e.insert(key.to_string(), buf);
902-
Ok(())
903-
}
904-
905967
fn remove_internal(
906968
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool,
907969
) -> io::Result<()> {
@@ -913,12 +975,23 @@ impl TestStore {
913975
}
914976

915977
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
978+
let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
916979

917980
let prefixed = format!("{primary_namespace}/{secondary_namespace}");
918981
if let Some(outer_ref) = persisted_lock.get_mut(&prefixed) {
919982
outer_ref.remove(&key.to_string());
920983
}
921984

985+
if let Some(pending_writes) = async_writes_lock.remove(&format!("{prefixed}/{key}")) {
986+
for (_, future, _) in pending_writes {
987+
let mut future_lock = future.lock().unwrap();
988+
future_lock.0 = Some(Ok(()));
989+
if let Some(waker) = future_lock.1.take() {
990+
waker.wake();
991+
}
992+
}
993+
}
994+
922995
Ok(())
923996
}
924997

@@ -945,8 +1018,15 @@ impl KVStore for TestStore {
9451018
fn write(
9461019
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
9471020
) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
948-
let res = self.write_internal(&primary_namespace, &secondary_namespace, &key, buf);
949-
Box::pin(async move { res })
1021+
let path = format!("{primary_namespace}/{secondary_namespace}/{key}");
1022+
let future = Arc::new(Mutex::new((None, None)));
1023+
1024+
let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
1025+
let pending_writes = async_writes_lock.entry(path).or_insert(Vec::new());
1026+
let new_id = pending_writes.last().map(|(id, _, _)| id + 1).unwrap_or(0);
1027+
pending_writes.push((new_id, Arc::clone(&future), buf));
1028+
1029+
Box::pin(OneShotChannel(future))
9501030
}
9511031
fn remove(
9521032
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
@@ -972,7 +1052,30 @@ impl KVStoreSync for TestStore {
9721052
fn write(
9731053
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
9741054
) -> io::Result<()> {
975-
self.write_internal(primary_namespace, secondary_namespace, key, buf)
1055+
if self.read_only {
1056+
return Err(io::Error::new(
1057+
io::ErrorKind::PermissionDenied,
1058+
"Cannot modify read-only store",
1059+
));
1060+
}
1061+
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
1062+
let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
1063+
1064+
let prefixed = format!("{primary_namespace}/{secondary_namespace}");
1065+
let async_writes_pending = async_writes_lock.remove(&format!("{prefixed}/{key}"));
1066+
let outer_e = persisted_lock.entry(prefixed).or_insert(new_hash_map());
1067+
outer_e.insert(key.to_string(), buf);
1068+
1069+
if let Some(pending_writes) = async_writes_pending {
1070+
for (_, future, _) in pending_writes {
1071+
let mut future_lock = future.lock().unwrap();
1072+
future_lock.0 = Some(Ok(()));
1073+
if let Some(waker) = future_lock.1.take() {
1074+
waker.wake();
1075+
}
1076+
}
1077+
}
1078+
Ok(())
9761079
}
9771080

9781081
fn remove(

0 commit comments

Comments
 (0)