Skip to content

Commit 74a3653

Browse files
authored
Fix: Recalculate the timeout duration considering open_timeout (ruby#15596)
This change updates the behavior so that, when there is only a single destination and `open_timeout` is specified, the remaining `open_timeout` duration is used as the connection timeout.
1 parent f483484 commit 74a3653

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

ext/socket/ipsocket.c

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ init_fast_fallback_inetsock_internal(VALUE v)
606606
struct timeval user_specified_open_timeout_storage;
607607
struct timeval *user_specified_open_timeout_at = NULL;
608608
struct timespec now = current_clocktime_ts();
609+
VALUE starts_at = current_clocktime();
609610

610611
if (!NIL_P(open_timeout)) {
611612
struct timeval open_timeout_tv = rb_time_interval(open_timeout);
@@ -619,7 +620,14 @@ init_fast_fallback_inetsock_internal(VALUE v)
619620
arg->getaddrinfo_shared = NULL;
620621

621622
int family = arg->families[0];
622-
unsigned int t = NIL_P(resolv_timeout) ? 0 : rsock_value_timeout_to_msec(resolv_timeout);
623+
unsigned int t;
624+
if (!NIL_P(open_timeout)) {
625+
t = rsock_value_timeout_to_msec(open_timeout);
626+
} else if (!NIL_P(open_timeout)) {
627+
t = rsock_value_timeout_to_msec(resolv_timeout);
628+
} else {
629+
t = 0;
630+
}
623631

624632
arg->remote.res = rsock_addrinfo(
625633
arg->remote.host,
@@ -833,14 +841,22 @@ init_fast_fallback_inetsock_internal(VALUE v)
833841
status = connect(fd, remote_ai->ai_addr, remote_ai->ai_addrlen);
834842
last_family = remote_ai->ai_family;
835843
} else {
836-
if (!NIL_P(connect_timeout)) {
837-
user_specified_connect_timeout_storage = rb_time_interval(connect_timeout);
838-
user_specified_connect_timeout_at = &user_specified_connect_timeout_storage;
844+
VALUE timeout = Qnil;
845+
846+
if (!NIL_P(open_timeout)) {
847+
VALUE elapsed = rb_funcall(current_clocktime(), '-', 1, starts_at);
848+
timeout = rb_funcall(open_timeout, '-', 1, elapsed);
849+
}
850+
if (NIL_P(timeout)) {
851+
if (!NIL_P(connect_timeout)) {
852+
user_specified_connect_timeout_storage = rb_time_interval(connect_timeout);
853+
user_specified_connect_timeout_at = &user_specified_connect_timeout_storage;
854+
}
855+
timeout =
856+
(user_specified_connect_timeout_at && is_infinity(*user_specified_connect_timeout_at)) ?
857+
Qnil : tv_to_seconds(user_specified_connect_timeout_at);
839858
}
840859

841-
VALUE timeout =
842-
(user_specified_connect_timeout_at && is_infinity(*user_specified_connect_timeout_at)) ?
843-
Qnil : tv_to_seconds(user_specified_connect_timeout_at);
844860
io = arg->io = rsock_init_sock(arg->self, fd);
845861
status = rsock_connect(io, remote_ai->ai_addr, remote_ai->ai_addrlen, 0, timeout);
846862
}
@@ -1305,13 +1321,22 @@ rsock_init_inetsock(
13051321
* Maybe also accept a local address
13061322
*/
13071323
if (!NIL_P(local_host) || !NIL_P(local_serv)) {
1324+
unsigned int t;
1325+
if (!NIL_P(open_timeout)) {
1326+
t = rsock_value_timeout_to_msec(open_timeout);
1327+
} else if (!NIL_P(open_timeout)) {
1328+
t = rsock_value_timeout_to_msec(resolv_timeout);
1329+
} else {
1330+
t = 0;
1331+
}
1332+
13081333
local_res = rsock_addrinfo(
13091334
local_host,
13101335
local_serv,
13111336
AF_UNSPEC,
13121337
SOCK_STREAM,
13131338
0,
1314-
0
1339+
t
13151340
);
13161341

13171342
struct addrinfo *tmp_p = local_res->ai;

ext/socket/lib/socket.rb

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def self.tcp(host, port, local_host = nil, local_port = nil, connect_timeout: ni
685685
# :stopdoc:
686686
def self.tcp_with_fast_fallback(host, port, local_host = nil, local_port = nil, connect_timeout: nil, resolv_timeout: nil, open_timeout: nil)
687687
if local_host || local_port
688-
local_addrinfos = Addrinfo.getaddrinfo(local_host, local_port, nil, :STREAM, timeout: resolv_timeout)
688+
local_addrinfos = Addrinfo.getaddrinfo(local_host, local_port, nil, :STREAM, timeout: open_timeout || resolv_timeout)
689689
resolving_family_names = local_addrinfos.map { |lai| ADDRESS_FAMILIES.key(lai.afamily) }.uniq
690690
else
691691
local_addrinfos = []
@@ -698,6 +698,7 @@ def self.tcp_with_fast_fallback(host, port, local_host = nil, local_port = nil,
698698
is_windows_environment ||= (RUBY_PLATFORM =~ /mswin|mingw|cygwin/)
699699

700700
now = current_clock_time
701+
starts_at = now
701702
resolution_delay_expires_at = nil
702703
connection_attempt_delay_expires_at = nil
703704
user_specified_connect_timeout_at = nil
@@ -707,7 +708,7 @@ def self.tcp_with_fast_fallback(host, port, local_host = nil, local_port = nil,
707708

708709
if resolving_family_names.size == 1
709710
family_name = resolving_family_names.first
710-
addrinfos = Addrinfo.getaddrinfo(host, port, ADDRESS_FAMILIES[:family_name], :STREAM, timeout: resolv_timeout)
711+
addrinfos = Addrinfo.getaddrinfo(host, port, ADDRESS_FAMILIES[:family_name], :STREAM, timeout: open_timeout || resolv_timeout)
711712
resolution_store.add_resolved(family_name, addrinfos)
712713
hostname_resolution_result = nil
713714
hostname_resolution_notifier = nil
@@ -724,7 +725,6 @@ def self.tcp_with_fast_fallback(host, port, local_host = nil, local_port = nil,
724725
thread
725726
}
726727
)
727-
728728
user_specified_resolv_timeout_at = resolv_timeout ? now + resolv_timeout : Float::INFINITY
729729
end
730730

@@ -758,9 +758,16 @@ def self.tcp_with_fast_fallback(host, port, local_host = nil, local_port = nil,
758758
socket.bind(local_addrinfo) if local_addrinfo
759759
result = socket.connect_nonblock(addrinfo, exception: false)
760760
else
761+
timeout =
762+
if open_timeout
763+
t = open_timeout - (current_clock_time - starts_at)
764+
t.negative? ? 0 : t
765+
else
766+
connect_timeout
767+
end
761768
result = socket = local_addrinfo ?
762-
addrinfo.connect_from(local_addrinfo, timeout: connect_timeout) :
763-
addrinfo.connect(timeout: connect_timeout)
769+
addrinfo.connect_from(local_addrinfo, timeout:) :
770+
addrinfo.connect(timeout:)
764771
end
765772

766773
if result == :wait_writable
@@ -934,7 +941,7 @@ def self.tcp_without_fast_fallback(host, port, local_host, local_port, connect_t
934941

935942
local_addr_list = nil
936943
if local_host != nil || local_port != nil
937-
local_addr_list = Addrinfo.getaddrinfo(local_host, local_port, nil, :STREAM, nil)
944+
local_addr_list = Addrinfo.getaddrinfo(local_host, local_port, nil, :STREAM, nil, timeout: open_timeout || resolv_timeout)
938945
end
939946

940947
timeout = open_timeout ? open_timeout : resolv_timeout

0 commit comments

Comments
 (0)