@@ -57,6 +57,21 @@ impl PallasChainReader {
57
57
. with_context ( || "PallasChainReader failed to get a client" )
58
58
}
59
59
60
+ #[ cfg( test) ]
61
+ /// Check if the client already exists (test only).
62
+ fn has_client ( & self ) -> bool {
63
+ self . client . is_some ( )
64
+ }
65
+
66
+ /// Drops the client by aborting the connection and setting it to `None`.
67
+ fn drop_client ( & mut self ) {
68
+ if let Some ( client) = self . client . take ( ) {
69
+ tokio:: spawn ( async move {
70
+ let _ = client. abort ( ) . await ;
71
+ } ) ;
72
+ }
73
+ }
74
+
60
75
/// Intersects the point of the chain with the given point.
61
76
async fn find_intersect_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
62
77
let logger = self . logger . clone ( ) ;
@@ -99,30 +114,38 @@ impl PallasChainReader {
99
114
100
115
impl Drop for PallasChainReader {
101
116
fn drop ( & mut self ) {
102
- if let Some ( client) = self . client . take ( ) {
103
- tokio:: spawn ( async move {
104
- let _ = client. abort ( ) . await ;
105
- } ) ;
106
- }
117
+ self . drop_client ( ) ;
107
118
}
108
119
}
109
120
110
121
#[ async_trait]
111
122
impl ChainBlockReader for PallasChainReader {
112
123
async fn set_chain_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
113
- self . find_intersect_point ( point) . await
124
+ match self . find_intersect_point ( point) . await {
125
+ Ok ( ( ) ) => Ok ( ( ) ) ,
126
+ Err ( err) => {
127
+ self . drop_client ( ) ;
128
+
129
+ return Err ( err) ;
130
+ }
131
+ }
114
132
}
115
133
116
134
async fn get_next_chain_block ( & mut self ) -> StdResult < Option < ChainBlockNextAction > > {
117
135
let client = self . get_client ( ) . await ?;
118
136
let chainsync = client. chainsync ( ) ;
119
-
120
137
let next = match chainsync. has_agency ( ) {
121
- true => chainsync. request_next ( ) . await ? ,
122
- false => chainsync. recv_while_must_reply ( ) . await ? ,
138
+ true => chainsync. request_next ( ) . await ,
139
+ false => chainsync. recv_while_must_reply ( ) . await ,
123
140
} ;
141
+ match next {
142
+ Ok ( next) => self . process_chain_block_next_action ( next) . await ,
143
+ Err ( err) => {
144
+ self . drop_client ( ) ;
124
145
125
- self . process_chain_block_next_action ( next) . await
146
+ return Err ( err. into ( ) ) ;
147
+ }
148
+ }
126
149
}
127
150
}
128
151
@@ -201,7 +224,7 @@ mod tests {
201
224
socket_path : PathBuf ,
202
225
action : ServerAction ,
203
226
has_agency : HasAgency ,
204
- ) -> tokio:: task:: JoinHandle < ( ) > {
227
+ ) -> tokio:: task:: JoinHandle < NodeServer > {
205
228
tokio:: spawn ( {
206
229
async move {
207
230
if socket_path. exists ( ) {
@@ -249,6 +272,8 @@ mod tests {
249
272
. unwrap ( ) ;
250
273
}
251
274
}
275
+
276
+ server
252
277
}
253
278
} )
254
279
}
@@ -375,4 +400,58 @@ mod tests {
375
400
_ => panic ! ( "Unexpected chain block action" ) ,
376
401
}
377
402
}
403
+
404
+ #[ tokio:: test]
405
+ async fn cached_client_is_dropped_when_returning_error ( ) {
406
+ let socket_path =
407
+ create_temp_dir ( "cached_client_is_dropped_when_returning_error" ) . join ( "node.socket" ) ;
408
+ let socket_path_clone = socket_path. clone ( ) ;
409
+ let known_point = get_fake_specific_point ( ) ;
410
+ let server = setup_server (
411
+ socket_path. clone ( ) ,
412
+ ServerAction :: RollForward ,
413
+ HasAgency :: Yes ,
414
+ )
415
+ . await ;
416
+ let client = tokio:: spawn ( async move {
417
+ let mut chain_reader = PallasChainReader :: new (
418
+ socket_path_clone. as_path ( ) ,
419
+ CardanoNetwork :: TestNet ( 10 ) ,
420
+ TestLogger :: stdout ( ) ,
421
+ ) ;
422
+
423
+ chain_reader
424
+ . set_chain_point ( & RawCardanoPoint :: from ( known_point. clone ( ) ) )
425
+ . await
426
+ . unwrap ( ) ;
427
+
428
+ chain_reader. get_next_chain_block ( ) . await . unwrap ( ) . unwrap ( ) ;
429
+
430
+ chain_reader
431
+ } ) ;
432
+
433
+ let ( server_res, client_res) = tokio:: join!( server, client) ;
434
+ let chain_reader = client_res. expect ( "Client failed to get chain reader" ) ;
435
+ let server = server_res. expect ( "Server failed to get server" ) ;
436
+ server. abort ( ) . await ;
437
+
438
+ let client = tokio:: spawn ( async move {
439
+ let mut chain_reader = chain_reader;
440
+
441
+ assert ! ( chain_reader. has_client( ) , "Client should exist" ) ;
442
+
443
+ chain_reader
444
+ . get_next_chain_block ( )
445
+ . await
446
+ . expect_err ( "Chain reader get_next_chain_block should fail" ) ;
447
+
448
+ assert ! (
449
+ !chain_reader. has_client( ) ,
450
+ "Client should have been dropped after error"
451
+ ) ;
452
+
453
+ chain_reader
454
+ } ) ;
455
+ client. await . unwrap ( ) ;
456
+ }
378
457
}
0 commit comments