@@ -57,6 +57,21 @@ impl PallasChainReader {
57
57
. with_context ( || "PallasChainReader failed to get a client" )
58
58
}
59
59
60
+ #[ cfg( test) ]
61
+ /// Check if the client already exists (test only).
62
+ fn has_client ( & self ) -> bool {
63
+ self . client . is_some ( )
64
+ }
65
+
66
+ /// Drops the client by aborting the connection and setting it to `None`.
67
+ fn drop_client ( & mut self ) {
68
+ if let Some ( client) = self . client . take ( ) {
69
+ tokio:: spawn ( async move {
70
+ let _ = client. abort ( ) . await ;
71
+ } ) ;
72
+ }
73
+ }
74
+
60
75
/// Intersects the point of the chain with the given point.
61
76
async fn find_intersect_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
62
77
let logger = self . logger . clone ( ) ;
@@ -99,30 +114,38 @@ impl PallasChainReader {
99
114
100
115
impl Drop for PallasChainReader {
101
116
fn drop ( & mut self ) {
102
- if let Some ( client) = self . client . take ( ) {
103
- tokio:: spawn ( async move {
104
- let _ = client. abort ( ) . await ;
105
- } ) ;
106
- }
117
+ self . drop_client ( ) ;
107
118
}
108
119
}
109
120
110
121
#[ async_trait]
111
122
impl ChainBlockReader for PallasChainReader {
112
123
async fn set_chain_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
113
- self . find_intersect_point ( point) . await
124
+ match self . find_intersect_point ( point) . await {
125
+ Ok ( ( ) ) => Ok ( ( ) ) ,
126
+ Err ( err) => {
127
+ self . drop_client ( ) ;
128
+
129
+ return Err ( err) ;
130
+ }
131
+ }
114
132
}
115
133
116
134
async fn get_next_chain_block ( & mut self ) -> StdResult < Option < ChainBlockNextAction > > {
117
135
let client = self . get_client ( ) . await ?;
118
136
let chainsync = client. chainsync ( ) ;
119
-
120
137
let next = match chainsync. has_agency ( ) {
121
- true => chainsync. request_next ( ) . await ? ,
122
- false => chainsync. recv_while_must_reply ( ) . await ? ,
138
+ true => chainsync. request_next ( ) . await ,
139
+ false => chainsync. recv_while_must_reply ( ) . await ,
123
140
} ;
141
+ match next {
142
+ Ok ( next) => self . process_chain_block_next_action ( next) . await ,
143
+ Err ( err) => {
144
+ self . drop_client ( ) ;
124
145
125
- self . process_chain_block_next_action ( next) . await
146
+ return Err ( err. into ( ) ) ;
147
+ }
148
+ }
126
149
}
127
150
}
128
151
@@ -142,6 +165,7 @@ mod tests {
142
165
use super :: * ;
143
166
144
167
use crate :: test_utils:: TestLogger ;
168
+ use crate :: * ;
145
169
use crate :: { entities:: BlockNumber , test_utils:: TempDir } ;
146
170
147
171
/// Enum representing the action to be performed by the server.
@@ -201,7 +225,7 @@ mod tests {
201
225
socket_path : PathBuf ,
202
226
action : ServerAction ,
203
227
has_agency : HasAgency ,
204
- ) -> tokio:: task:: JoinHandle < ( ) > {
228
+ ) -> tokio:: task:: JoinHandle < NodeServer > {
205
229
tokio:: spawn ( {
206
230
async move {
207
231
if socket_path. exists ( ) {
@@ -249,14 +273,15 @@ mod tests {
249
273
. unwrap ( ) ;
250
274
}
251
275
}
276
+
277
+ server
252
278
}
253
279
} )
254
280
}
255
281
256
282
#[ tokio:: test]
257
283
async fn get_next_chain_block_rolls_backward ( ) {
258
- let socket_path =
259
- create_temp_dir ( "get_next_chain_block_rolls_backward" ) . join ( "node.socket" ) ;
284
+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
260
285
let known_point = get_fake_specific_point ( ) ;
261
286
let server = setup_server (
262
287
socket_path. clone ( ) ,
@@ -291,7 +316,7 @@ mod tests {
291
316
292
317
#[ tokio:: test]
293
318
async fn get_next_chain_block_rolls_forward ( ) {
294
- let socket_path = create_temp_dir ( "get_next_chain_block_rolls_forward" ) . join ( "node.socket" ) ;
319
+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
295
320
let known_point = get_fake_specific_point ( ) ;
296
321
let server = setup_server (
297
322
socket_path. clone ( ) ,
@@ -326,7 +351,7 @@ mod tests {
326
351
327
352
#[ tokio:: test]
328
353
async fn get_next_chain_block_has_no_agency ( ) {
329
- let socket_path = create_temp_dir ( "get_next_chain_block_has_no_agency" ) . join ( "node.socket" ) ;
354
+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
330
355
let known_point = get_fake_specific_point ( ) ;
331
356
let server = setup_server (
332
357
socket_path. clone ( ) ,
@@ -375,4 +400,57 @@ mod tests {
375
400
_ => panic ! ( "Unexpected chain block action" ) ,
376
401
}
377
402
}
403
+
404
+ #[ tokio:: test]
405
+ async fn cached_client_is_dropped_when_returning_error ( ) {
406
+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
407
+ let socket_path_clone = socket_path. clone ( ) ;
408
+ let known_point = get_fake_specific_point ( ) ;
409
+ let server = setup_server (
410
+ socket_path. clone ( ) ,
411
+ ServerAction :: RollForward ,
412
+ HasAgency :: Yes ,
413
+ )
414
+ . await ;
415
+ let client = tokio:: spawn ( async move {
416
+ let mut chain_reader = PallasChainReader :: new (
417
+ socket_path_clone. as_path ( ) ,
418
+ CardanoNetwork :: TestNet ( 10 ) ,
419
+ TestLogger :: stdout ( ) ,
420
+ ) ;
421
+
422
+ chain_reader
423
+ . set_chain_point ( & RawCardanoPoint :: from ( known_point. clone ( ) ) )
424
+ . await
425
+ . unwrap ( ) ;
426
+
427
+ chain_reader. get_next_chain_block ( ) . await . unwrap ( ) . unwrap ( ) ;
428
+
429
+ chain_reader
430
+ } ) ;
431
+
432
+ let ( server_res, client_res) = tokio:: join!( server, client) ;
433
+ let chain_reader = client_res. expect ( "Client failed to get chain reader" ) ;
434
+ let server = server_res. expect ( "Server failed to get server" ) ;
435
+ server. abort ( ) . await ;
436
+
437
+ let client = tokio:: spawn ( async move {
438
+ let mut chain_reader = chain_reader;
439
+
440
+ assert ! ( chain_reader. has_client( ) , "Client should exist" ) ;
441
+
442
+ chain_reader
443
+ . get_next_chain_block ( )
444
+ . await
445
+ . expect_err ( "Chain reader get_next_chain_block should fail" ) ;
446
+
447
+ assert ! (
448
+ !chain_reader. has_client( ) ,
449
+ "Client should have been dropped after error"
450
+ ) ;
451
+
452
+ chain_reader
453
+ } ) ;
454
+ client. await . unwrap ( ) ;
455
+ }
378
456
}
0 commit comments