Skip to content

Commit ccff588

Browse files
committed
feat(socket): fix some stuff, add blocking connect to ReqSocket
1 parent f63565a commit ccff588

File tree

4 files changed

+87
-5
lines changed

4 files changed

+87
-5
lines changed

msg-socket/src/req/driver.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use tokio::{
1515
time::Interval,
1616
};
1717
use tokio_util::codec::Framed;
18-
use tracing::{debug, error, trace};
18+
use tracing::{debug, error, trace, warn};
1919

2020
use super::{Command, ReqError, ReqOptions};
2121
use crate::{ConnectionState, ExponentialBackoff, req::SocketState};
@@ -327,7 +327,7 @@ where
327327
}
328328
Poll::Ready(Some(Err(err))) => {
329329
if let reqrep::Error::Io(e) = err {
330-
error!(err = ?e, "Socket wire error");
330+
error!(err = ?e, "wire error, resetting connection state");
331331
}
332332

333333
// set the connection to inactive, so that it will be re-tried
@@ -336,9 +336,12 @@ where
336336
continue;
337337
}
338338
Poll::Ready(None) => {
339-
debug!("Connection to {:?} closed, shutting down driver", this.addr);
339+
warn!(peer = ?this.addr, "connection closed, resetting connection state");
340340

341-
return Poll::Ready(());
341+
// set the connection to inactive, so that it will be re-tried
342+
this.reset_connection();
343+
344+
continue;
342345
}
343346
Poll::Pending => {}
344347
}

msg-socket/src/req/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ pub enum ReqError {
3535
Timeout,
3636
#[error("Could not connect to any valid endpoints")]
3737
NoValidEndpoints,
38+
#[error("Failed to connect to the target endpoint: {0:?}")]
39+
Connect(Box<dyn std::error::Error + Send + Sync>),
3840
}
3941

4042
/// Commands that can be sent to the request socket driver.

msg-socket/src/req/socket.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ where
115115
// Initialize communication channels
116116
let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
117117

118-
let transport = self.transport.take().expect("Transport has been moved already");
118+
// TODO: Don't panic, return error
119+
let mut transport = self.transport.take().expect("Transport has been moved already");
119120

120121
// We initialize the connection as inactive, and let it be activated
121122
// by the backend task as soon as the driver is spawned.
@@ -124,6 +125,13 @@ where
124125
backoff: ExponentialBackoff::new(Duration::from_millis(20), 16),
125126
};
126127

128+
if self.options.blocking_connect {
129+
transport
130+
.connect(endpoint.clone())
131+
.await
132+
.map_err(|e| ReqError::Connect(Box::new(e)))?;
133+
}
134+
127135
let timeout_check_interval = tokio::time::interval(self.options.timeout / 10);
128136

129137
let flush_interval = self.options.flush_interval.map(tokio::time::interval);

msg-socket/tests/it/reqrep.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::time::Duration;
2+
13
use bytes::Bytes;
24
use msg_socket::{RepSocket, ReqSocket};
35
use msg_transport::{
@@ -144,3 +146,70 @@ async fn reqrep_mutual_tls_works() {
144146
let response = req.request(hello.clone()).await.unwrap();
145147
assert_eq!(hello, response, "expected {:?}, got {:?}", hello, response);
146148
}
149+
150+
#[tokio::test]
151+
async fn reqrep_late_bind_works() {
152+
let _ = tracing_subscriber::fmt::try_init();
153+
154+
let mut rep = RepSocket::new(Tcp::default());
155+
let mut req = ReqSocket::new(Tcp::default());
156+
157+
let local_addr = "localhost:64521";
158+
req.connect(local_addr).await.unwrap();
159+
160+
let hello = Bytes::from_static(b"hello");
161+
162+
let reply = tokio::spawn(async move { req.request(hello.clone()).await.unwrap() });
163+
164+
tokio::time::sleep(Duration::from_millis(1000)).await;
165+
rep.bind(local_addr).await.unwrap();
166+
167+
let msg = rep.next().await.unwrap();
168+
let payload = msg.msg().clone();
169+
msg.respond(payload).unwrap();
170+
171+
let response = reply.await.unwrap();
172+
let hello = Bytes::from_static(b"hello");
173+
assert_eq!(hello, response, "expected {:?}, got {:?}", hello, response);
174+
}
175+
176+
#[tokio::test]
177+
async fn reqrep_drop_server() {
178+
let _ = tracing_subscriber::fmt::try_init();
179+
180+
let mut rep = RepSocket::new(Tcp::default());
181+
let mut req = ReqSocket::new(Tcp::default());
182+
183+
rep.bind("0.0.0.0:0").await.unwrap();
184+
185+
let addr = rep.local_addr().unwrap().clone();
186+
req.connect(addr).await.unwrap();
187+
188+
tokio::spawn(async move {
189+
let request = rep.next().await.unwrap();
190+
let msg = request.msg().clone();
191+
request.respond(msg).unwrap();
192+
193+
drop(rep);
194+
});
195+
196+
let hello = Bytes::from_static(b"hello");
197+
let response = req.request(hello.clone()).await.unwrap();
198+
assert_eq!(hello, response, "expected {:?}, got {:?}", hello, response);
199+
200+
match req.request(hello.clone()).await {
201+
Ok(response) => assert_eq!(hello, response, "expected {:?}, got {:?}", hello, response),
202+
Err(e) => tracing::warn!("Error: {:?}", e),
203+
}
204+
205+
tokio::time::sleep(Duration::from_secs(60)).await;
206+
207+
tokio::spawn(async move {
208+
req.request(hello.clone()).await.unwrap();
209+
});
210+
211+
let mut rep = RepSocket::new(Tcp::default());
212+
rep.bind(addr).await.unwrap();
213+
214+
tokio::time::sleep(Duration::from_millis(10000)).await;
215+
}

0 commit comments

Comments
 (0)