Skip to content

Commit 5cc32e1

Browse files
authored
hook setsockopt (#316)
2 parents e6d85dc + 8d2fa34 commit 5cc32e1

File tree

9 files changed

+377
-26
lines changed

9 files changed

+377
-26
lines changed

core/src/common/constants.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pub enum Syscall {
7575
kevent,
7676
#[cfg(windows)]
7777
iocp,
78+
setsockopt,
7879
recv,
7980
#[cfg(windows)]
8081
WSARecv,

core/src/syscall/common.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
pub use crate::syscall::{is_blocking, is_non_blocking, set_blocking, set_errno, set_non_blocking};
1+
pub use crate::syscall::{
2+
is_blocking, is_non_blocking, recv_time_limit, send_time_limit, set_blocking, set_errno,
3+
set_non_blocking,
4+
};
25

36
pub extern "C" fn reset_errno() {
47
set_errno(0);

core/src/syscall/unix/mod.rs

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use crate::syscall_mod;
2+
use dashmap::DashMap;
3+
use once_cell::sync::Lazy;
24
use std::ffi::c_int;
35

46
macro_rules! impl_facade {
@@ -131,6 +133,7 @@ macro_rules! impl_nio_read {
131133
if blocking {
132134
$crate::syscall::common::set_non_blocking($fd);
133135
}
136+
let start_time = $crate::common::now();
134137
let mut r;
135138
loop {
136139
r = self.inner.$syscall(fn_ptr, $fd, $($arg, )*);
@@ -141,9 +144,13 @@ macro_rules! impl_nio_read {
141144
let error_kind = std::io::Error::last_os_error().kind();
142145
if error_kind == std::io::ErrorKind::WouldBlock {
143146
//wait read event
147+
let wait_time = std::time::Duration::from_nanos(start_time
148+
.saturating_add($crate::syscall::common::recv_time_limit($fd))
149+
.saturating_sub($crate::common::now()))
150+
.min($crate::common::constants::SLICE);
144151
if $crate::net::EventLoops::wait_read_event(
145152
$fd,
146-
Some($crate::common::constants::SLICE),
153+
Some(wait_time)
147154
).is_err() {
148155
break;
149156
}
@@ -182,6 +189,7 @@ macro_rules! impl_nio_read_buf {
182189
if blocking {
183190
$crate::syscall::common::set_non_blocking($fd);
184191
}
192+
let start_time = $crate::common::now();
185193
let mut received = 0;
186194
let mut r = 0;
187195
while received < $len {
@@ -203,12 +211,14 @@ macro_rules! impl_nio_read_buf {
203211
let error_kind = std::io::Error::last_os_error().kind();
204212
if error_kind == std::io::ErrorKind::WouldBlock {
205213
//wait read event
214+
let wait_time = std::time::Duration::from_nanos(start_time
215+
.saturating_add($crate::syscall::common::recv_time_limit($fd))
216+
.saturating_sub($crate::common::now()))
217+
.min($crate::common::constants::SLICE);
206218
if $crate::net::EventLoops::wait_read_event(
207219
$fd,
208-
Some($crate::common::constants::SLICE),
209-
)
210-
.is_err()
211-
{
220+
Some(wait_time)
221+
).is_err() {
212222
break;
213223
}
214224
} else if error_kind != std::io::ErrorKind::Interrupted {
@@ -247,6 +257,7 @@ macro_rules! impl_nio_read_iovec {
247257
$crate::syscall::common::set_non_blocking($fd);
248258
}
249259
let vec = unsafe { Vec::from_raw_parts($iov.cast_mut(), $iovcnt as usize, $iovcnt as usize) };
260+
let start_time = $crate::common::now();
250261
let mut length = 0;
251262
let mut received = 0usize;
252263
let mut r = 0;
@@ -296,9 +307,13 @@ macro_rules! impl_nio_read_iovec {
296307
let error_kind = std::io::Error::last_os_error().kind();
297308
if error_kind == std::io::ErrorKind::WouldBlock {
298309
//wait read event
310+
let wait_time = std::time::Duration::from_nanos(start_time
311+
.saturating_add($crate::syscall::common::recv_time_limit($fd))
312+
.saturating_sub($crate::common::now()))
313+
.min($crate::common::constants::SLICE);
299314
if $crate::net::EventLoops::wait_read_event(
300315
$fd,
301-
Some($crate::common::constants::SLICE)
316+
Some(wait_time)
302317
).is_err() {
303318
std::mem::forget(vec);
304319
if blocking {
@@ -350,6 +365,7 @@ macro_rules! impl_nio_write_buf {
350365
if blocking {
351366
$crate::syscall::common::set_non_blocking($fd);
352367
}
368+
let start_time = $crate::common::now();
353369
let mut sent = 0;
354370
let mut r = 0;
355371
while sent < $len {
@@ -371,9 +387,13 @@ macro_rules! impl_nio_write_buf {
371387
let error_kind = std::io::Error::last_os_error().kind();
372388
if error_kind == std::io::ErrorKind::WouldBlock {
373389
//wait write event
390+
let wait_time = std::time::Duration::from_nanos(start_time
391+
.saturating_add($crate::syscall::common::send_time_limit($fd))
392+
.saturating_sub($crate::common::now()))
393+
.min($crate::common::constants::SLICE);
374394
if $crate::net::EventLoops::wait_write_event(
375395
$fd,
376-
Some($crate::common::constants::SLICE),
396+
Some(wait_time),
377397
)
378398
.is_err()
379399
{
@@ -415,6 +435,7 @@ macro_rules! impl_nio_write_iovec {
415435
$crate::syscall::common::set_non_blocking($fd);
416436
}
417437
let vec = unsafe { Vec::from_raw_parts($iov.cast_mut(), $iovcnt as usize, $iovcnt as usize) };
438+
let start_time = $crate::common::now();
418439
let mut length = 0;
419440
let mut sent = 0usize;
420441
let mut r = 0;
@@ -458,9 +479,13 @@ macro_rules! impl_nio_write_iovec {
458479
let error_kind = std::io::Error::last_os_error().kind();
459480
if error_kind == std::io::ErrorKind::WouldBlock {
460481
//wait write event
482+
let wait_time = std::time::Duration::from_nanos(start_time
483+
.saturating_add($crate::syscall::common::send_time_limit($fd))
484+
.saturating_sub($crate::common::now()))
485+
.min($crate::common::constants::SLICE);
461486
if $crate::net::EventLoops::wait_write_event(
462487
$fd,
463-
Some($crate::common::constants::SLICE)
488+
Some(wait_time)
464489
).is_err() {
465490
std::mem::forget(vec);
466491
if blocking {
@@ -541,6 +566,7 @@ syscall_mod!(
541566
shutdown;
542567
sleep;
543568
socket;
569+
setsockopt;
544570
usleep;
545571
write;
546572
writev;
@@ -551,6 +577,10 @@ syscall_mod!(
551577
unlink
552578
);
553579

580+
static SEND_TIME_LIMIT: Lazy<DashMap<c_int, u64>> = Lazy::new(Default::default);
581+
582+
static RECV_TIME_LIMIT: Lazy<DashMap<c_int, u64>> = Lazy::new(Default::default);
583+
554584
extern "C" {
555585
#[cfg(not(any(target_os = "dragonfly", target_os = "vxworks")))]
556586
#[cfg_attr(
@@ -636,3 +666,63 @@ pub extern "C" fn is_non_blocking(fd: c_int) -> bool {
636666
}
637667
(flags & libc::O_NONBLOCK) != 0
638668
}
669+
670+
#[must_use]
671+
pub extern "C" fn send_time_limit(fd: c_int) -> u64 {
672+
SEND_TIME_LIMIT.get(&fd).map_or_else(
673+
|| unsafe {
674+
let mut tv: libc::timeval = std::mem::zeroed();
675+
let mut len = size_of::<libc::timeval>() as libc::socklen_t;
676+
assert_eq!(
677+
0,
678+
libc::getsockopt(
679+
fd,
680+
libc::SOL_SOCKET,
681+
libc::SO_SNDTIMEO,
682+
std::ptr::from_mut(&mut tv).cast(),
683+
&mut len,
684+
)
685+
);
686+
let mut time_limit = (tv.tv_sec as u64)
687+
.saturating_mul(1_000_000_000)
688+
.saturating_add((tv.tv_usec as u64).saturating_mul(1_000));
689+
if 0 == time_limit {
690+
// 取消超时
691+
time_limit = u64::MAX;
692+
}
693+
assert!(SEND_TIME_LIMIT.insert(fd, time_limit).is_none());
694+
time_limit
695+
},
696+
|v| *v.value(),
697+
)
698+
}
699+
700+
#[must_use]
701+
pub extern "C" fn recv_time_limit(fd: c_int) -> u64 {
702+
RECV_TIME_LIMIT.get(&fd).map_or_else(
703+
|| unsafe {
704+
let mut tv: libc::timeval = std::mem::zeroed();
705+
let mut len = size_of::<libc::timeval>() as libc::socklen_t;
706+
assert_eq!(
707+
0,
708+
libc::getsockopt(
709+
fd,
710+
libc::SOL_SOCKET,
711+
libc::SO_RCVTIMEO,
712+
std::ptr::from_mut(&mut tv).cast(),
713+
&mut len,
714+
)
715+
);
716+
let mut time_limit = (tv.tv_sec as u64)
717+
.saturating_mul(1_000_000_000)
718+
.saturating_add((tv.tv_usec as u64).saturating_mul(1_000));
719+
if 0 == time_limit {
720+
// 取消超时
721+
time_limit = u64::MAX;
722+
}
723+
assert!(RECV_TIME_LIMIT.insert(fd, time_limit).is_none());
724+
time_limit
725+
},
726+
|v| *v.value(),
727+
)
728+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
use std::ffi::{c_int, c_void};
2+
use libc::socklen_t;
3+
use once_cell::sync::Lazy;
4+
use crate::syscall::unix::{RECV_TIME_LIMIT, SEND_TIME_LIMIT};
5+
6+
#[must_use]
7+
pub extern "C" fn setsockopt(
8+
fn_ptr: Option<&extern "C" fn(c_int, c_int, c_int, *const c_void, socklen_t) -> c_int>,
9+
socket: c_int,
10+
level: c_int,
11+
name: c_int,
12+
value: *const c_void,
13+
option_len: socklen_t
14+
) -> c_int{
15+
static CHAIN: Lazy<SetsockoptSyscallFacade<NioSetsockoptSyscall<RawSetsockoptSyscall>>> =
16+
Lazy::new(Default::default);
17+
CHAIN.setsockopt(fn_ptr, socket, level, name, value, option_len)
18+
}
19+
20+
trait SetsockoptSyscall {
21+
extern "C" fn setsockopt(
22+
&self,
23+
fn_ptr: Option<&extern "C" fn(c_int, c_int, c_int, *const c_void, socklen_t) -> c_int>,
24+
socket: c_int,
25+
level: c_int,
26+
name: c_int,
27+
value: *const c_void,
28+
option_len: socklen_t
29+
) -> c_int;
30+
}
31+
32+
impl_facade!(SetsockoptSyscallFacade, SetsockoptSyscall,
33+
setsockopt(socket: c_int, level: c_int, name: c_int, value: *const c_void, option_len: socklen_t) -> c_int
34+
);
35+
36+
#[repr(C)]
37+
#[derive(Debug, Default)]
38+
struct NioSetsockoptSyscall<I: SetsockoptSyscall> {
39+
inner: I,
40+
}
41+
42+
impl<I: SetsockoptSyscall> SetsockoptSyscall for NioSetsockoptSyscall<I> {
43+
extern "C" fn setsockopt(
44+
&self,
45+
fn_ptr: Option<&extern "C" fn(c_int, c_int, c_int, *const c_void, socklen_t) -> c_int>,
46+
socket: c_int,
47+
level: c_int,
48+
name: c_int,
49+
value: *const c_void,
50+
option_len: socklen_t
51+
) -> c_int {
52+
let r= self.inner.setsockopt(fn_ptr, socket, level, name, value, option_len);
53+
if 0 == r && libc::SOL_SOCKET == level {
54+
if libc::SO_SNDTIMEO == name {
55+
let tv = unsafe { &*value.cast::<libc::timeval>() };
56+
let mut time_limit = (tv.tv_sec as u64)
57+
.saturating_mul(1_000_000_000)
58+
.saturating_add((tv.tv_usec as u64).saturating_mul(1_000));
59+
if 0 == time_limit {
60+
// 取消超时
61+
time_limit = u64::MAX;
62+
}
63+
assert!(SEND_TIME_LIMIT.insert(socket, time_limit).is_none());
64+
} else if libc::SO_RCVTIMEO == name {
65+
let tv = unsafe { &*value.cast::<libc::timeval>() };
66+
let mut time_limit = (tv.tv_sec as u64)
67+
.saturating_mul(1_000_000_000)
68+
.saturating_add((tv.tv_usec as u64).saturating_mul(1_000));
69+
if 0 == time_limit {
70+
// 取消超时
71+
time_limit = u64::MAX;
72+
}
73+
assert!(RECV_TIME_LIMIT.insert(socket, time_limit).is_none());
74+
}
75+
}
76+
r
77+
}
78+
}
79+
80+
impl_raw!(RawSetsockoptSyscall, SetsockoptSyscall,
81+
setsockopt(socket: c_int, level: c_int, name: c_int, value: *const c_void, option_len: socklen_t) -> c_int
82+
);

0 commit comments

Comments
 (0)