Skip to content

Commit 6e09dbc

Browse files
authored
Merge pull request #2501 from input-output-hk/jpraynaud/2426-fix-chain-reader-client-cache
Fix: chain reader client cache drop on error
2 parents c9834b2 + 5cbd4f8 commit 6e09dbc

File tree

3 files changed

+95
-17
lines changed

3 files changed

+95
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mithril-common/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "mithril-common"
3-
version = "0.5.31"
3+
version = "0.5.32"
44
description = "Common types, interfaces, and utilities for Mithril nodes."
55
authors = { workspace = true }
66
edition = { workspace = true }

mithril-common/src/chain_reader/pallas_chain_reader.rs

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ impl PallasChainReader {
5757
.with_context(|| "PallasChainReader failed to get a client")
5858
}
5959

60+
#[cfg(test)]
61+
/// Check if the client already exists (test only).
62+
fn has_client(&self) -> bool {
63+
self.client.is_some()
64+
}
65+
66+
/// Drops the client by aborting the connection and setting it to `None`.
67+
fn drop_client(&mut self) {
68+
if let Some(client) = self.client.take() {
69+
tokio::spawn(async move {
70+
let _ = client.abort().await;
71+
});
72+
}
73+
}
74+
6075
/// Intersects the point of the chain with the given point.
6176
async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
6277
let logger = self.logger.clone();
@@ -99,30 +114,38 @@ impl PallasChainReader {
99114

100115
impl Drop for PallasChainReader {
101116
fn drop(&mut self) {
102-
if let Some(client) = self.client.take() {
103-
tokio::spawn(async move {
104-
let _ = client.abort().await;
105-
});
106-
}
117+
self.drop_client();
107118
}
108119
}
109120

110121
#[async_trait]
111122
impl ChainBlockReader for PallasChainReader {
112123
async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
113-
self.find_intersect_point(point).await
124+
match self.find_intersect_point(point).await {
125+
Ok(()) => Ok(()),
126+
Err(err) => {
127+
self.drop_client();
128+
129+
return Err(err);
130+
}
131+
}
114132
}
115133

116134
async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
117135
let client = self.get_client().await?;
118136
let chainsync = client.chainsync();
119-
120137
let next = match chainsync.has_agency() {
121-
true => chainsync.request_next().await?,
122-
false => chainsync.recv_while_must_reply().await?,
138+
true => chainsync.request_next().await,
139+
false => chainsync.recv_while_must_reply().await,
123140
};
141+
match next {
142+
Ok(next) => self.process_chain_block_next_action(next).await,
143+
Err(err) => {
144+
self.drop_client();
124145

125-
self.process_chain_block_next_action(next).await
146+
return Err(err.into());
147+
}
148+
}
126149
}
127150
}
128151

@@ -142,6 +165,7 @@ mod tests {
142165
use super::*;
143166

144167
use crate::test_utils::TestLogger;
168+
use crate::*;
145169
use crate::{entities::BlockNumber, test_utils::TempDir};
146170

147171
/// Enum representing the action to be performed by the server.
@@ -201,7 +225,7 @@ mod tests {
201225
socket_path: PathBuf,
202226
action: ServerAction,
203227
has_agency: HasAgency,
204-
) -> tokio::task::JoinHandle<()> {
228+
) -> tokio::task::JoinHandle<NodeServer> {
205229
tokio::spawn({
206230
async move {
207231
if socket_path.exists() {
@@ -249,14 +273,15 @@ mod tests {
249273
.unwrap();
250274
}
251275
}
276+
277+
server
252278
}
253279
})
254280
}
255281

256282
#[tokio::test]
257283
async fn get_next_chain_block_rolls_backward() {
258-
let socket_path =
259-
create_temp_dir("get_next_chain_block_rolls_backward").join("node.socket");
284+
let socket_path = create_temp_dir(current_function!()).join("node.socket");
260285
let known_point = get_fake_specific_point();
261286
let server = setup_server(
262287
socket_path.clone(),
@@ -291,7 +316,7 @@ mod tests {
291316

292317
#[tokio::test]
293318
async fn get_next_chain_block_rolls_forward() {
294-
let socket_path = create_temp_dir("get_next_chain_block_rolls_forward").join("node.socket");
319+
let socket_path = create_temp_dir(current_function!()).join("node.socket");
295320
let known_point = get_fake_specific_point();
296321
let server = setup_server(
297322
socket_path.clone(),
@@ -326,7 +351,7 @@ mod tests {
326351

327352
#[tokio::test]
328353
async fn get_next_chain_block_has_no_agency() {
329-
let socket_path = create_temp_dir("get_next_chain_block_has_no_agency").join("node.socket");
354+
let socket_path = create_temp_dir(current_function!()).join("node.socket");
330355
let known_point = get_fake_specific_point();
331356
let server = setup_server(
332357
socket_path.clone(),
@@ -375,4 +400,57 @@ mod tests {
375400
_ => panic!("Unexpected chain block action"),
376401
}
377402
}
403+
404+
#[tokio::test]
405+
async fn cached_client_is_dropped_when_returning_error() {
406+
let socket_path = create_temp_dir(current_function!()).join("node.socket");
407+
let socket_path_clone = socket_path.clone();
408+
let known_point = get_fake_specific_point();
409+
let server = setup_server(
410+
socket_path.clone(),
411+
ServerAction::RollForward,
412+
HasAgency::Yes,
413+
)
414+
.await;
415+
let client = tokio::spawn(async move {
416+
let mut chain_reader = PallasChainReader::new(
417+
socket_path_clone.as_path(),
418+
CardanoNetwork::TestNet(10),
419+
TestLogger::stdout(),
420+
);
421+
422+
chain_reader
423+
.set_chain_point(&RawCardanoPoint::from(known_point.clone()))
424+
.await
425+
.unwrap();
426+
427+
chain_reader.get_next_chain_block().await.unwrap().unwrap();
428+
429+
chain_reader
430+
});
431+
432+
let (server_res, client_res) = tokio::join!(server, client);
433+
let chain_reader = client_res.expect("Client failed to get chain reader");
434+
let server = server_res.expect("Server failed to get server");
435+
server.abort().await;
436+
437+
let client = tokio::spawn(async move {
438+
let mut chain_reader = chain_reader;
439+
440+
assert!(chain_reader.has_client(), "Client should exist");
441+
442+
chain_reader
443+
.get_next_chain_block()
444+
.await
445+
.expect_err("Chain reader get_next_chain_block should fail");
446+
447+
assert!(
448+
!chain_reader.has_client(),
449+
"Client should have been dropped after error"
450+
);
451+
452+
chain_reader
453+
});
454+
client.await.unwrap();
455+
}
378456
}

0 commit comments

Comments
 (0)