Skip to content

Commit 556e1a1

Browse files
committed
refactor(socket): implement ObjectInterface for sockets directly
1 parent ec0fc35 commit 556e1a1

File tree

4 files changed

+43
-46
lines changed

4 files changed

+43
-46
lines changed

src/fd/socket/tcp.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -475,66 +475,63 @@ impl Drop for Socket {
475475
}
476476

477477
#[async_trait]
478-
impl ObjectInterface for async_lock::RwLock<Socket> {
478+
impl ObjectInterface for Socket {
479479
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
480-
self.read().await.poll(event).await
480+
self.poll(event).await
481481
}
482482

483483
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
484-
self.read().await.read(buffer).await
484+
self.read(buffer).await
485485
}
486486

487487
async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
488-
self.read().await.write(buffer).await
488+
self.write(buffer).await
489489
}
490490

491491
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
492-
self.write().await.bind(endpoint).await
492+
self.bind(endpoint).await
493493
}
494494

495495
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
496-
self.write().await.connect(endpoint).await
496+
self.connect(endpoint).await
497497
}
498498

499499
async fn accept(
500500
&mut self,
501501
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
502-
let (socket, endpoint) = self.write().await.accept().await?;
503-
Ok((
504-
Arc::new(async_lock::RwLock::new(async_lock::RwLock::new(socket))),
505-
endpoint,
506-
))
502+
let (socket, endpoint) = self.accept().await?;
503+
Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
507504
}
508505

509506
async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
510-
self.read().await.getpeername().await
507+
self.getpeername().await
511508
}
512509

513510
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
514-
self.read().await.getsockname().await
511+
self.getsockname().await
515512
}
516513

517514
async fn listen(&mut self, backlog: i32) -> io::Result<()> {
518-
self.write().await.listen(backlog).await
515+
self.listen(backlog).await
519516
}
520517

521518
async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
522-
self.read().await.setsockopt(opt, optval).await
519+
self.setsockopt(opt, optval).await
523520
}
524521

525522
async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
526-
self.read().await.getsockopt(opt).await
523+
self.getsockopt(opt).await
527524
}
528525

529526
async fn shutdown(&self, how: i32) -> io::Result<()> {
530-
self.read().await.shutdown(how).await
527+
self.shutdown(how).await
531528
}
532529

533530
async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
534-
self.read().await.status_flags().await
531+
self.status_flags().await
535532
}
536533

537534
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
538-
self.write().await.set_status_flags(status_flags).await
535+
self.set_status_flags(status_flags).await
539536
}
540537
}

src/fd/socket/udp.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -253,44 +253,44 @@ impl Drop for Socket {
253253
}
254254

255255
#[async_trait]
256-
impl ObjectInterface for async_lock::RwLock<Socket> {
256+
impl ObjectInterface for Socket {
257257
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
258-
self.read().await.poll(event).await
258+
self.poll(event).await
259259
}
260260

261261
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
262-
self.write().await.bind(endpoint).await
262+
self.bind(endpoint).await
263263
}
264264

265265
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
266-
self.write().await.connect(endpoint).await
266+
self.connect(endpoint).await
267267
}
268268

269269
async fn sendto(&self, buffer: &[u8], endpoint: Endpoint) -> io::Result<usize> {
270-
self.read().await.sendto(buffer, endpoint).await
270+
self.sendto(buffer, endpoint).await
271271
}
272272

273273
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
274-
self.read().await.recvfrom(buffer).await
274+
self.recvfrom(buffer).await
275275
}
276276

277277
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
278-
self.read().await.read(buffer).await
278+
self.read(buffer).await
279279
}
280280

281281
async fn write(&self, buf: &[u8]) -> io::Result<usize> {
282-
self.read().await.write(buf).await
282+
self.write(buf).await
283283
}
284284

285285
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
286-
self.read().await.getsockname().await
286+
self.getsockname().await
287287
}
288288

289289
async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
290-
self.read().await.status_flags().await
290+
self.status_flags().await
291291
}
292292

293293
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
294-
self.write().await.set_status_flags(status_flags).await
294+
self.set_status_flags(status_flags).await
295295
}
296296
}

src/fd/socket/vsock.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -421,55 +421,55 @@ impl Drop for Socket {
421421
}
422422

423423
#[async_trait]
424-
impl ObjectInterface for async_lock::RwLock<Socket> {
424+
impl ObjectInterface for Socket {
425425
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
426-
self.read().await.poll(event).await
426+
self.poll(event).await
427427
}
428428

429429
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
430-
self.read().await.read(buffer).await
430+
self.read(buffer).await
431431
}
432432

433433
async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
434-
self.read().await.write(buffer).await
434+
self.write(buffer).await
435435
}
436436

437437
async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
438-
self.write().await.bind(endpoint).await
438+
self.bind(endpoint).await
439439
}
440440

441441
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
442-
self.write().await.connect(endpoint).await
442+
self.connect(endpoint).await
443443
}
444444

445445
async fn accept(
446446
&mut self,
447447
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
448-
let (handle, endpoint) = self.write().await.accept().await?;
448+
let (handle, endpoint) = self.accept().await?;
449449
Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint))
450450
}
451451

452452
async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
453-
self.read().await.getpeername().await
453+
self.getpeername().await
454454
}
455455

456456
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
457-
self.read().await.getsockname().await
457+
self.getsockname().await
458458
}
459459

460460
async fn listen(&mut self, backlog: i32) -> io::Result<()> {
461-
self.write().await.listen(backlog).await
461+
self.listen(backlog).await
462462
}
463463

464464
async fn shutdown(&self, how: i32) -> io::Result<()> {
465-
self.read().await.shutdown(how).await
465+
self.shutdown(how).await
466466
}
467467

468468
async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
469-
self.read().await.status_flags().await
469+
self.status_flags().await
470470
}
471471

472472
async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
473-
self.write().await.set_status_flags(status_flags).await
473+
self.set_status_flags(status_flags).await
474474
}
475475
}

src/syscalls/socket/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
588588

589589
#[cfg(feature = "vsock")]
590590
if domain == Af::Vsock && sock == Sock::Stream {
591-
let mut socket = async_lock::RwLock::new(vsock::Socket::new());
591+
let mut socket = vsock::Socket::new();
592592

593593
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
594594
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -610,7 +610,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
610610
if sock == Sock::Dgram {
611611
let handle = nic.create_udp_handle().unwrap();
612612
drop(guard);
613-
let mut socket = async_lock::RwLock::new(udp::Socket::new(handle, domain));
613+
let mut socket = udp::Socket::new(handle, domain);
614614

615615
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
616616
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -626,7 +626,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
626626
if sock == Sock::Stream {
627627
let handle = nic.create_tcp_handle().unwrap();
628628
drop(guard);
629-
let mut socket = async_lock::RwLock::new(tcp::Socket::new(handle, domain));
629+
let mut socket = tcp::Socket::new(handle, domain);
630630

631631
if sock_flags.contains(SockFlags::SOCK_NONBLOCK) {
632632
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();

0 commit comments

Comments
 (0)