Skip to content

Commit fb03875

Browse files
authored
Merge pull request #1900 from hermit-os/embedded-io-async
refactor(io): prepare for async I/O traits
2 parents 0c8198c + 34ebc81 commit fb03875

File tree

12 files changed

+225
-244
lines changed

12 files changed

+225
-244
lines changed

src/fd/mod.rs

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -233,25 +233,27 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
233233

234234
/// `accept` a connection on a socket
235235
#[cfg(any(feature = "net", feature = "vsock"))]
236-
async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
236+
async fn accept(
237+
&mut self,
238+
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
237239
Err(Errno::Inval)
238240
}
239241

240242
/// initiate a connection on a socket
241243
#[cfg(any(feature = "net", feature = "vsock"))]
242-
async fn connect(&self, _endpoint: Endpoint) -> io::Result<()> {
244+
async fn connect(&mut self, _endpoint: Endpoint) -> io::Result<()> {
243245
Err(Errno::Inval)
244246
}
245247

246248
/// `bind` a name to a socket
247249
#[cfg(any(feature = "net", feature = "vsock"))]
248-
async fn bind(&self, _name: ListenEndpoint) -> io::Result<()> {
250+
async fn bind(&mut self, _name: ListenEndpoint) -> io::Result<()> {
249251
Err(Errno::Inval)
250252
}
251253

252254
/// `listen` for connections on a socket
253255
#[cfg(any(feature = "net", feature = "vsock"))]
254-
async fn listen(&self, _backlog: i32) -> io::Result<()> {
256+
async fn listen(&mut self, _backlog: i32) -> io::Result<()> {
255257
Err(Errno::Inval)
256258
}
257259

@@ -310,7 +312,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
310312
}
311313

312314
/// Sets the file status flags.
313-
async fn set_status_flags(&self, _status_flags: StatusFlags) -> io::Result<()> {
315+
async fn set_status_flags(&mut self, _status_flags: StatusFlags) -> io::Result<()> {
314316
Err(Errno::Nosys)
315317
}
316318

@@ -337,19 +339,19 @@ pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result<usize> {
337339
return Ok(0);
338340
}
339341

340-
block_on(obj.read(buf), None)
342+
block_on(async { obj.read().await.read(buf).await }, None)
341343
}
342344

343345
pub(crate) fn lseek(fd: FileDescriptor, offset: isize, whence: SeekWhence) -> io::Result<isize> {
344346
let obj = get_object(fd)?;
345347

346-
block_on(obj.lseek(offset, whence), None)
348+
block_on(async { obj.read().await.lseek(offset, whence).await }, None)
347349
}
348350

349351
pub(crate) fn chmod(fd: FileDescriptor, mode: AccessPermission) -> io::Result<()> {
350352
let obj = get_object(fd)?;
351353

352-
block_on(obj.chmod(mode), None)
354+
block_on(async { obj.read().await.chmod(mode).await }, None)
353355
}
354356

355357
pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result<usize> {
@@ -359,12 +361,12 @@ pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result<usize> {
359361
return Ok(0);
360362
}
361363

362-
block_on(obj.write(buf), None)
364+
block_on(async { obj.read().await.write(buf).await }, None)
363365
}
364366

365367
pub(crate) fn truncate(fd: FileDescriptor, length: usize) -> io::Result<()> {
366368
let obj = get_object(fd)?;
367-
block_on(obj.truncate(length), None)
369+
block_on(async { obj.read().await.truncate(length).await }, None)
368370
}
369371

370372
async fn poll_fds(fds: &mut [PollFd]) -> io::Result<u64> {
@@ -375,7 +377,7 @@ async fn poll_fds(fds: &mut [PollFd]) -> io::Result<u64> {
375377
let fd = i.fd;
376378
i.revents = PollEvent::empty();
377379
if let Ok(obj) = core_scheduler().get_object(fd) {
378-
let mut pinned = core::pin::pin!(obj.poll(i.events));
380+
let mut pinned = core::pin::pin!(async { obj.read().await.poll(i.events).await });
379381
if let Ready(Ok(e)) = pinned.as_mut().poll(cx)
380382
&& !e.is_empty()
381383
{
@@ -416,7 +418,7 @@ pub fn poll(fds: &mut [PollFd], timeout: Option<Duration>) -> io::Result<u64> {
416418

417419
pub fn fstat(fd: FileDescriptor) -> io::Result<FileAttr> {
418420
let obj = get_object(fd)?;
419-
block_on(obj.fstat(), None)
421+
block_on(async { obj.read().await.fstat().await }, None)
420422
}
421423

422424
/// Wait for some event on a file descriptor.
@@ -440,16 +442,20 @@ pub fn fstat(fd: FileDescriptor) -> io::Result<FileAttr> {
440442
pub fn eventfd(initval: u64, flags: EventFlags) -> io::Result<FileDescriptor> {
441443
let obj = self::eventfd::EventFd::new(initval, flags);
442444

443-
let fd = core_scheduler().insert_object(Arc::new(obj))?;
445+
let fd = core_scheduler().insert_object(Arc::new(async_lock::RwLock::new(obj)))?;
444446

445447
Ok(fd)
446448
}
447449

448-
pub(crate) fn get_object(fd: FileDescriptor) -> io::Result<Arc<dyn ObjectInterface>> {
450+
pub(crate) fn get_object(
451+
fd: FileDescriptor,
452+
) -> io::Result<Arc<async_lock::RwLock<dyn ObjectInterface>>> {
449453
core_scheduler().get_object(fd)
450454
}
451455

452-
pub(crate) fn insert_object(obj: Arc<dyn ObjectInterface>) -> io::Result<FileDescriptor> {
456+
pub(crate) fn insert_object(
457+
obj: Arc<async_lock::RwLock<dyn ObjectInterface>>,
458+
) -> io::Result<FileDescriptor> {
453459
core_scheduler().insert_object(obj)
454460
}
455461

@@ -465,11 +471,13 @@ pub(crate) fn dup_object2(fd1: FileDescriptor, fd2: FileDescriptor) -> io::Resul
465471
core_scheduler().dup_object2(fd1, fd2)
466472
}
467473

468-
pub(crate) fn remove_object(fd: FileDescriptor) -> io::Result<Arc<dyn ObjectInterface>> {
474+
pub(crate) fn remove_object(
475+
fd: FileDescriptor,
476+
) -> io::Result<Arc<async_lock::RwLock<dyn ObjectInterface>>> {
469477
core_scheduler().remove_object(fd)
470478
}
471479

472480
pub(crate) fn isatty(fd: FileDescriptor) -> io::Result<bool> {
473481
let obj = get_object(fd)?;
474-
block_on(obj.isatty(), None)
482+
block_on(async { obj.read().await.isatty().await }, None)
475483
}

src/fd/socket/tcp.rs

Lines changed: 8 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ impl Socket {
110110
})
111111
.await
112112
}
113+
}
113114

115+
#[async_trait]
116+
impl ObjectInterface for Socket {
114117
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
115118
future::poll_fn(|cx| {
116119
self.with(|socket| match socket.state() {
@@ -275,7 +278,7 @@ impl Socket {
275278
}
276279
}
277280

278-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
281+
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
279282
#[allow(irrefutable_let_patterns)]
280283
if let Endpoint::Ip(endpoint) = endpoint {
281284
self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port()))
@@ -298,7 +301,9 @@ impl Socket {
298301
}
299302
}
300303

301-
async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> {
304+
async fn accept(
305+
&mut self,
306+
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
302307
if !self.is_listen {
303308
self.listen(DEFAULT_BACKLOG).await?;
304309
}
@@ -357,7 +362,7 @@ impl Socket {
357362
is_listen: false,
358363
};
359364

360-
Ok((socket, endpoint))
365+
Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
361366
}
362367

363368
async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
@@ -473,63 +478,3 @@ impl Drop for Socket {
473478
}
474479
}
475480
}
476-
477-
#[async_trait]
478-
impl ObjectInterface for async_lock::RwLock<Socket> {
479-
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
480-
self.read().await.poll(event).await
481-
}
482-
483-
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
484-
self.read().await.read(buffer).await
485-
}
486-
487-
async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
488-
self.read().await.write(buffer).await
489-
}
490-
491-
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
492-
self.write().await.bind(endpoint).await
493-
}
494-
495-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
496-
self.read().await.connect(endpoint).await
497-
}
498-
499-
async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
500-
let (socket, endpoint) = self.write().await.accept().await?;
501-
Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
502-
}
503-
504-
async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
505-
self.read().await.getpeername().await
506-
}
507-
508-
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
509-
self.read().await.getsockname().await
510-
}
511-
512-
async fn listen(&self, backlog: i32) -> io::Result<()> {
513-
self.write().await.listen(backlog).await
514-
}
515-
516-
async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
517-
self.read().await.setsockopt(opt, optval).await
518-
}
519-
520-
async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
521-
self.read().await.getsockopt(opt).await
522-
}
523-
524-
async fn shutdown(&self, how: i32) -> io::Result<()> {
525-
self.read().await.shutdown(how).await
526-
}
527-
528-
async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
529-
self.read().await.status_flags().await
530-
}
531-
532-
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
533-
self.write().await.set_status_flags(status_flags).await
534-
}
535-
}

src/fd/socket/udp.rs

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ impl Socket {
7474
})
7575
.await
7676
}
77+
}
7778

79+
#[async_trait]
80+
impl ObjectInterface for Socket {
7881
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
7982
future::poll_fn(|cx| {
8083
self.with(|socket| {
@@ -251,46 +254,3 @@ impl Drop for Socket {
251254
NIC.lock().as_nic_mut().unwrap().destroy_socket(self.handle);
252255
}
253256
}
254-
255-
#[async_trait]
256-
impl ObjectInterface for async_lock::RwLock<Socket> {
257-
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
258-
self.read().await.poll(event).await
259-
}
260-
261-
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
262-
self.write().await.bind(endpoint).await
263-
}
264-
265-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
266-
self.write().await.connect(endpoint).await
267-
}
268-
269-
async fn sendto(&self, buffer: &[u8], endpoint: Endpoint) -> io::Result<usize> {
270-
self.read().await.sendto(buffer, endpoint).await
271-
}
272-
273-
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
274-
self.read().await.recvfrom(buffer).await
275-
}
276-
277-
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
278-
self.read().await.read(buffer).await
279-
}
280-
281-
async fn write(&self, buf: &[u8]) -> io::Result<usize> {
282-
self.read().await.write(buf).await
283-
}
284-
285-
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
286-
self.read().await.getsockname().await
287-
}
288-
289-
async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
290-
self.read().await.status_flags().await
291-
}
292-
293-
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
294-
self.write().await.set_status_flags(status_flags).await
295-
}
296-
}

0 commit comments

Comments
 (0)