Skip to content

Commit 09cd926

Browse files
author
Hui Zhu
authored
Merge pull request #38 from Tim-Zhang/fix-async-vsock
Fix vsock in async mode
2 parents 426ff24 + 1e8a7d2 commit 09cd926

File tree

5 files changed

+96
-26
lines changed

5 files changed

+96
-26
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async-trait = "0.1.31"
2121

2222
tokio = { version = "0.2", features = ["rt-threaded", "sync", "uds", "stream", "macros", "io-util"] }
2323
futures = "0.3"
24+
tokio-vsock = "0.2.1"
2425

2526
[build-dependencies]
2627
protobuf-codegen-pure = "2.14.0"

example/Cargo.lock

Lines changed: 30 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/asynchronous/server.rs

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,37 @@ use std::os::unix::io::RawFd;
99
use std::sync::Arc;
1010

1111
use crate::asynchronous::stream::{receive, respond, respond_with_status};
12-
use crate::common;
13-
use crate::common::MESSAGE_TYPE_REQUEST;
12+
use crate::common::{self, Domain, MESSAGE_TYPE_REQUEST};
1413
use crate::error::{get_status, Error, Result};
1514
use crate::r#async::{MethodHandler, TtrpcContext};
1615
use crate::ttrpc::{Code, Request};
1716
use crate::MessageHeader;
1817
use futures::StreamExt as _;
18+
use std::marker::Unpin;
1919
use std::os::unix::io::FromRawFd;
20+
use std::os::unix::net::UnixListener as SysUnixListener;
2021
use tokio::{
2122
self,
2223
io::split,
2324
net::UnixListener,
2425
prelude::*,
26+
stream::Stream,
2527
sync::mpsc::{channel, Receiver, Sender},
2628
};
29+
use tokio_vsock::VsockListener;
2730

2831
pub struct Server {
2932
listeners: Vec<RawFd>,
3033
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
34+
domain: Option<Domain>,
3135
}
3236

3337
impl Default for Server {
3438
fn default() -> Self {
3539
Server {
3640
listeners: Vec::with_capacity(1),
3741
methods: Arc::new(HashMap::new()),
42+
domain: None,
3843
}
3944
}
4045
}
@@ -51,9 +56,10 @@ impl Server {
5156
));
5257
}
5358

54-
let fd = common::do_bind(host)?;
55-
self.listeners.push(fd);
59+
let (fd, domain) = common::do_bind(host)?;
60+
self.domain = Some(domain);
5661

62+
self.listeners.push(fd);
5763
Ok(self)
5864
}
5965

@@ -77,21 +83,42 @@ impl Server {
7783
return Err(Error::Others("ttrpc-rust not bind".to_string()));
7884
}
7985

80-
let listener = self.listeners[0];
81-
common::do_listen(listener)?;
86+
let listenfd = self.listeners[0];
87+
common::do_listen(listenfd)?;
8288

83-
Ok(listener)
89+
Ok(listenfd)
8490
}
8591

8692
pub async fn start(&self) -> Result<()> {
87-
let listener = self.listen()?;
88-
let sys_unix_listener: std::os::unix::net::UnixListener;
89-
unsafe {
90-
sys_unix_listener = std::os::unix::net::UnixListener::from_raw_fd(listener);
93+
let listenfd = self.listen()?;
94+
95+
match self.domain.as_ref().unwrap() {
96+
Domain::Unix => {
97+
let sys_unix_listener;
98+
unsafe {
99+
sys_unix_listener = SysUnixListener::from_raw_fd(listenfd);
100+
}
101+
let mut unix_listener = UnixListener::from_std(sys_unix_listener).unwrap();
102+
let incoming = unix_listener.incoming();
103+
104+
self.do_start(listenfd, incoming).await
105+
}
106+
Domain::Vsock => {
107+
let incoming;
108+
unsafe {
109+
incoming = VsockListener::from_raw_fd(listenfd).incoming();
110+
}
111+
112+
self.do_start(listenfd, incoming).await
113+
}
91114
}
92-
let mut unix_listener = UnixListener::from_std(sys_unix_listener).unwrap();
93-
let mut incoming = unix_listener.incoming();
115+
}
94116

117+
pub async fn do_start<I, S>(&self, listenfd: RawFd, mut incoming: I) -> Result<()>
118+
where
119+
I: Stream<Item = std::io::Result<S>> + Unpin,
120+
S: AsyncRead + AsyncWrite + Send + 'static,
121+
{
95122
while let Some(result) = incoming.next().await {
96123
match result {
97124
Ok(stream) => {
@@ -115,7 +142,7 @@ impl Server {
115142
match receive(&mut reader).await {
116143
Ok(message) => {
117144
tokio::spawn(async move {
118-
handle_request(tx, listener, methods, message).await;
145+
handle_request(tx, listenfd, methods, message).await;
119146
});
120147
}
121148
Err(e) => {

src/common.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ use nix::sys::socket::*;
1111
use std::os::unix::io::RawFd;
1212
use std::str::FromStr;
1313

14+
#[derive(Debug)]
15+
pub enum Domain {
16+
Unix,
17+
Vsock,
18+
}
19+
1420
#[derive(Default, Debug)]
1521
pub struct MessageHeader {
1622
pub length: u32,
@@ -30,18 +36,29 @@ pub fn do_listen(listener: RawFd) -> Result<()> {
3036
listen(listener, 10).map_err(|e| Error::Socket(e.to_string()))
3137
}
3238

33-
pub fn do_bind(host: &str) -> Result<RawFd> {
39+
pub fn parse_host(host: &str) -> Result<(Domain, Vec<&str>)> {
3440
let hostv: Vec<&str> = host.trim().split("://").collect();
3541
if hostv.len() != 2 {
3642
return Err(Error::Others(format!("Host {} is not right", host)));
3743
}
38-
let scheme = hostv[0].to_lowercase();
44+
45+
let domain = match &hostv[0].to_lowercase()[..] {
46+
"unix" => Domain::Unix,
47+
"vsock" => Domain::Vsock,
48+
x => return Err(Error::Others(format!("Scheme {:?} is not supported", x))),
49+
};
50+
51+
Ok((domain, hostv))
52+
}
53+
54+
pub fn do_bind(host: &str) -> Result<(RawFd, Domain)> {
55+
let (domain, hostv) = parse_host(host)?;
3956

4057
let sockaddr: SockAddr;
4158
let fd: RawFd;
4259

43-
match scheme.as_str() {
44-
"unix" => {
60+
match domain {
61+
Domain::Unix => {
4562
fd = socket(
4663
AddressFamily::Unix,
4764
SockType::Stream,
@@ -54,8 +71,7 @@ pub fn do_bind(host: &str) -> Result<RawFd> {
5471
UnixAddr::new_abstract(sockaddr_h.as_bytes()).map_err(err_to_Others!(e, ""))?;
5572
sockaddr = SockAddr::Unix(sockaddr_u);
5673
}
57-
58-
"vsock" => {
74+
Domain::Vsock => {
5975
let host_port_v: Vec<&str> = hostv[1].split(':').collect();
6076
if host_port_v.len() != 2 {
6177
return Err(Error::Others(format!(
@@ -75,12 +91,11 @@ pub fn do_bind(host: &str) -> Result<RawFd> {
7591
.map_err(|e| Error::Socket(e.to_string()))?;
7692
sockaddr = SockAddr::new_vsock(cid, port);
7793
}
78-
_ => return Err(Error::Others(format!("Scheme {} is not supported", scheme))),
7994
};
8095

8196
bind(fd, &sockaddr).map_err(err_to_Others!(e, ""))?;
8297

83-
Ok(fd)
98+
Ok((fd, domain))
8499
}
85100

86101
macro_rules! cfg_sync {

src/sync/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ impl Server {
275275
));
276276
}
277277

278-
let fd = common::do_bind(host)?;
278+
let (fd, _) = common::do_bind(host)?;
279279
self.listeners.push(fd);
280280

281281
Ok(self)

0 commit comments

Comments
 (0)