Skip to content

Commit ec0fc35

Browse files
committed
refactor(fd): make accept, bind, connect, listen, set_status_flags mutable
1 parent 884cdcc commit ec0fc35

File tree

6 files changed

+38
-32
lines changed

6 files changed

+38
-32
lines changed

src/fd/mod.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,25 +233,27 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
233233

234234
/// `accept` a connection on a socket
235235
#[cfg(any(feature = "net", feature = "vsock"))]
236-
async fn accept(&self) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
236+
async fn accept(
237+
&mut self,
238+
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
237239
Err(Errno::Inval)
238240
}
239241

240242
/// initiate a connection on a socket
241243
#[cfg(any(feature = "net", feature = "vsock"))]
242-
async fn connect(&self, _endpoint: Endpoint) -> io::Result<()> {
244+
async fn connect(&mut self, _endpoint: Endpoint) -> io::Result<()> {
243245
Err(Errno::Inval)
244246
}
245247

246248
/// `bind` a name to a socket
247249
#[cfg(any(feature = "net", feature = "vsock"))]
248-
async fn bind(&self, _name: ListenEndpoint) -> io::Result<()> {
250+
async fn bind(&mut self, _name: ListenEndpoint) -> io::Result<()> {
249251
Err(Errno::Inval)
250252
}
251253

252254
/// `listen` for connections on a socket
253255
#[cfg(any(feature = "net", feature = "vsock"))]
254-
async fn listen(&self, _backlog: i32) -> io::Result<()> {
256+
async fn listen(&mut self, _backlog: i32) -> io::Result<()> {
255257
Err(Errno::Inval)
256258
}
257259

@@ -310,7 +312,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
310312
}
311313

312314
/// Sets the file status flags.
313-
async fn set_status_flags(&self, _status_flags: StatusFlags) -> io::Result<()> {
315+
async fn set_status_flags(&mut self, _status_flags: StatusFlags) -> io::Result<()> {
314316
Err(Errno::Nosys)
315317
}
316318

src/fd/socket/tcp.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ impl Socket {
275275
}
276276
}
277277

278-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
278+
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
279279
#[allow(irrefutable_let_patterns)]
280280
if let Endpoint::Ip(endpoint) = endpoint {
281281
self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port()))
@@ -488,15 +488,17 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
488488
self.read().await.write(buffer).await
489489
}
490490

491-
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
491+
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
492492
self.write().await.bind(endpoint).await
493493
}
494494

495-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
496-
self.read().await.connect(endpoint).await
495+
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
496+
self.write().await.connect(endpoint).await
497497
}
498498

499-
async fn accept(&self) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
499+
async fn accept(
500+
&mut self,
501+
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
500502
let (socket, endpoint) = self.write().await.accept().await?;
501503
Ok((
502504
Arc::new(async_lock::RwLock::new(async_lock::RwLock::new(socket))),
@@ -512,7 +514,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
512514
self.read().await.getsockname().await
513515
}
514516

515-
async fn listen(&self, backlog: i32) -> io::Result<()> {
517+
async fn listen(&mut self, backlog: i32) -> io::Result<()> {
516518
self.write().await.listen(backlog).await
517519
}
518520

@@ -532,7 +534,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
532534
self.read().await.status_flags().await
533535
}
534536

535-
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
537+
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
536538
self.write().await.set_status_flags(status_flags).await
537539
}
538540
}

src/fd/socket/udp.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,11 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
258258
self.read().await.poll(event).await
259259
}
260260

261-
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
261+
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
262262
self.write().await.bind(endpoint).await
263263
}
264264

265-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
265+
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
266266
self.write().await.connect(endpoint).await
267267
}
268268

@@ -290,7 +290,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
290290
self.read().await.status_flags().await
291291
}
292292

293-
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
293+
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
294294
self.write().await.set_status_flags(status_flags).await
295295
}
296296
}

src/fd/socket/vsock.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ impl Socket {
232232
))))
233233
}
234234

235-
async fn listen(&self, _backlog: i32) -> io::Result<()> {
235+
async fn listen(&mut self, _backlog: i32) -> io::Result<()> {
236236
Ok(())
237237
}
238238

@@ -434,15 +434,17 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
434434
self.read().await.write(buffer).await
435435
}
436436

437-
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
437+
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
438438
self.write().await.bind(endpoint).await
439439
}
440440

441-
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
441+
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
442442
self.write().await.connect(endpoint).await
443443
}
444444

445-
async fn accept(&self) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
445+
async fn accept(
446+
&mut self,
447+
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
446448
let (handle, endpoint) = self.write().await.accept().await?;
447449
Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint))
448450
}
@@ -455,7 +457,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
455457
self.read().await.getsockname().await
456458
}
457459

458-
async fn listen(&self, backlog: i32) -> io::Result<()> {
460+
async fn listen(&mut self, backlog: i32) -> io::Result<()> {
459461
self.write().await.listen(backlog).await
460462
}
461463

@@ -467,7 +469,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
467469
self.read().await.status_flags().await
468470
}
469471

470-
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
472+
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
471473
self.write().await.set_status_flags(status_flags).await
472474
}
473475
}

src/syscalls/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ pub unsafe extern "C" fn sys_ioctl(
642642
|e| -i32::from(e),
643643
|v| {
644644
block_on(
645-
async { v.read().await.set_status_flags(status_flags).await },
645+
async { v.write().await.set_status_flags(status_flags).await },
646646
None,
647647
)
648648
.map_or_else(|e| -i32::from(e), |()| 0)
@@ -680,7 +680,7 @@ pub extern "C" fn sys_fcntl(fd: i32, cmd: i32, arg: i32) -> i32 {
680680
|v| {
681681
block_on(
682682
async {
683-
v.read()
683+
v.write()
684684
.await
685685
.set_status_flags(fd::StatusFlags::from_bits_retain(arg))
686686
.await

src/syscalls/socket/mod.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
588588

589589
#[cfg(feature = "vsock")]
590590
if domain == Af::Vsock && sock == Sock::Stream {
591-
let socket = async_lock::RwLock::new(vsock::Socket::new());
591+
let mut socket = async_lock::RwLock::new(vsock::Socket::new());
592592

593593
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
594594
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -610,7 +610,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
610610
if sock == Sock::Dgram {
611611
let handle = nic.create_udp_handle().unwrap();
612612
drop(guard);
613-
let socket = async_lock::RwLock::new(udp::Socket::new(handle, domain));
613+
let mut socket = async_lock::RwLock::new(udp::Socket::new(handle, domain));
614614

615615
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
616616
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -626,7 +626,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
626626
if sock == Sock::Stream {
627627
let handle = nic.create_tcp_handle().unwrap();
628628
drop(guard);
629-
let socket = async_lock::RwLock::new(tcp::Socket::new(handle, domain));
629+
let mut socket = async_lock::RwLock::new(tcp::Socket::new(handle, domain));
630630

631631
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
632632
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -650,7 +650,7 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut
650650
obj.map_or_else(
651651
|e| -i32::from(e),
652652
|v| {
653-
block_on(async { v.read().await.accept().await }, None).map_or_else(
653+
block_on(async { v.write().await.accept().await }, None).map_or_else(
654654
|e| -i32::from(e),
655655
#[cfg_attr(not(feature = "net"), expect(unused_variables))]
656656
|(obj, endpoint)| match endpoint {
@@ -712,7 +712,7 @@ pub extern "C" fn sys_listen(fd: i32, backlog: i32) -> i32 {
712712
obj.map_or_else(
713713
|e| -i32::from(e),
714714
|v| {
715-
block_on(async { v.read().await.listen(backlog).await }, None)
715+
block_on(async { v.write().await.listen(backlog).await }, None)
716716
.map_or_else(|e| -i32::from(e), |()| 0)
717717
},
718718
)
@@ -740,7 +740,7 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl
740740
}
741741
let endpoint = IpListenEndpoint::from(unsafe { *name.cast::<sockaddr_in>() });
742742
block_on(
743-
async { v.read().await.bind(ListenEndpoint::Ip(endpoint)).await },
743+
async { v.write().await.bind(ListenEndpoint::Ip(endpoint)).await },
744744
None,
745745
)
746746
.map_or_else(|e| -i32::from(e), |()| 0)
@@ -752,7 +752,7 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl
752752
}
753753
let endpoint = IpListenEndpoint::from(unsafe { *name.cast::<sockaddr_in6>() });
754754
block_on(
755-
async { v.read().await.bind(ListenEndpoint::Ip(endpoint)).await },
755+
async { v.write().await.bind(ListenEndpoint::Ip(endpoint)).await },
756756
None,
757757
)
758758
.map_or_else(|e| -i32::from(e), |()| 0)
@@ -764,7 +764,7 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl
764764
}
765765
let endpoint = VsockListenEndpoint::from(unsafe { *name.cast::<sockaddr_vm>() });
766766
block_on(
767-
async { v.read().await.bind(ListenEndpoint::Vsock(endpoint)).await },
767+
async { v.write().await.bind(ListenEndpoint::Vsock(endpoint)).await },
768768
None,
769769
)
770770
.map_or_else(|e| -i32::from(e), |()| 0)
@@ -816,7 +816,7 @@ pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: so
816816
obj.map_or_else(
817817
|e| -i32::from(e),
818818
|v| {
819-
block_on(async { v.read().await.connect(endpoint).await }, None)
819+
block_on(async { v.write().await.connect(endpoint).await }, None)
820820
.map_or_else(|e| -i32::from(e), |()| 0)
821821
},
822822
)

0 commit comments

Comments
 (0)