diff --git a/lib/dalli/server.rb b/lib/dalli/server.rb index 8e089a73..15ddd64f 100644 --- a/lib/dalli/server.rb +++ b/lib/dalli/server.rb @@ -583,7 +583,7 @@ def flush_response def write(bytes) begin @inprogress = true - result = @sock.write(bytes) + result = @sock.writefull(bytes) @inprogress = false result rescue SystemCallError, Timeout::Error => e diff --git a/lib/dalli/socket.rb b/lib/dalli/socket.rb index 13d2124b..04d4e0d2 100644 --- a/lib/dalli/socket.rb +++ b/lib/dalli/socket.rb @@ -1,5 +1,7 @@ # frozen_string_literal: true +require 'resolv' + module Dalli module Socket module InstanceMethods @@ -21,6 +23,22 @@ def readfull(count) value end + def writefull(bytes) + offset = 0 + while offset < bytes.bytesize + chunk = offset == 0 ? bytes : bytes.byteslice(offset..-1) + result = write_nonblock(chunk, exception: false) + if result == :wait_writable + raise Timeout::Error, "IO timeout: #{safe_options.inspect}" unless IO.select(nil, [self], nil, options[:socket_timeout]) + elsif result == :wait_readable + raise Timeout::Error, "IO timeout: #{safe_options.inspect}" unless IO.select([self], nil, nil, options[:socket_timeout]) + else + offset += result + end + end + offset + end + def read_available value = +"" loop do @@ -43,35 +61,97 @@ def safe_options end end - class TCP < TCPSocket + class TCP < ::Socket include Dalli::Socket::InstanceMethods attr_accessor :options, :server def self.open(host, port, server, options = {}) - Timeout.timeout(options[:socket_timeout]) do - sock = new(host, port) - sock.options = {host: host, port: port}.merge(options) - sock.server = server - sock.setsockopt(::Socket::IPPROTO_TCP, ::Socket::TCP_NODELAY, true) - sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_KEEPALIVE, true) if options[:keepalive] - sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_RCVBUF, options[:rcvbuf]) if options[:rcvbuf] - sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_SNDBUF, options[:sndbuf]) if options[:sndbuf] - sock + addr_info = resolve_address(host, options[:socket_timeout]) + sock = new(addr_info[4], ::Socket::SOCK_STREAM, 0) # addr_info[4] == address family constant (e.g. AF_INET), expressed as an integer + + sock.setsockopt(::Socket::IPPROTO_TCP, ::Socket::TCP_NODELAY, true) + sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_KEEPALIVE, true) if options[:keepalive] + sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_RCVBUF, options[:rcvbuf]) if options[:rcvbuf] + sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_SNDBUF, options[:sndbuf]) if options[:sndbuf] + + sockaddr = ::Socket.pack_sockaddr_in(port, addr_info[3]) # addr_info[3] == IP address string (e.g. "192.168.1.1") + result = sock.connect_nonblock(sockaddr, exception: false) + if result == :wait_writable + unless IO.select(nil, [sock], nil, options[:socket_timeout]) + raise Timeout::Error, "Connection timeout: #{host}:#{port}" + end + begin + sock.connect_nonblock(sockaddr) + rescue Errno::EISCONN + # already connected + end end + + sock.options = { host: host, port: port }.merge(options) + sock.server = server + sock + rescue + sock&.close rescue nil + raise + end + + # Resolve a hostname to structured address info with timeout protection. + # getaddrinfo(3) is a blocking C library call that can block indefinitely + # on unresponsive DNS. For IP addresses (the common case with memcached), + # getaddrinfo returns immediately without DNS and is safe to call directly. + # For hostnames, we use Ruby's Resolv library which is pure Ruby and + # supports timeouts, then pass the resolved IP to getaddrinfo for the + # structured address info the caller expects. + def self.resolve_address(host, timeout) + if ip_address?(host) + return ::Socket.getaddrinfo(host, nil, ::Socket::AF_UNSPEC, ::Socket::SOCK_STREAM).first + end + + dns = Resolv::DNS.new + dns.timeouts = timeout + resolver = Resolv.new([Resolv::Hosts.new, dns]) + resolved_ip = resolver.getaddress(host).to_s + ::Socket.getaddrinfo(resolved_ip, nil, ::Socket::AF_UNSPEC, ::Socket::SOCK_STREAM).first + rescue Resolv::ResolvError => e + raise SocketError, "getaddrinfo: Name or service not known - #{host} (#{e.message})" + ensure + dns&.close end + private_class_method :resolve_address + + # Returns true if host is an IP address (v4 or v6) rather than a hostname. + def self.ip_address?(host) + host.match?(Resolv::IPv4::Regex) || host.include?(':') + end + private_class_method :ip_address? end - class UNIX < UNIXSocket + class UNIX < ::Socket include Dalli::Socket::InstanceMethods attr_accessor :options, :server def self.open(path, server, options = {}) - Timeout.timeout(options[:socket_timeout]) do - sock = new(path) - sock.options = {path: path}.merge(options) - sock.server = server - sock + sock = new(::Socket::AF_UNIX, ::Socket::SOCK_STREAM, 0) + sockaddr = ::Socket.pack_sockaddr_un(path) + + result = sock.connect_nonblock(sockaddr, exception: false) + if result == :wait_writable + unless IO.select(nil, [sock], nil, options[:socket_timeout]) + raise Timeout::Error, "Connection timeout: #{path}" + end + begin + sock.connect_nonblock(sockaddr) + rescue Errno::EISCONN + # already connected + end end + + sock.options = { path: path }.merge(options) + sock.server = server + sock + rescue + sock&.close rescue nil + raise end end end diff --git a/test/test_socket.rb b/test/test_socket.rb index ad2983a5..51c169cc 100644 --- a/test/test_socket.rb +++ b/test/test_socket.rb @@ -3,12 +3,14 @@ class MockSocket include Dalli::Socket::InstanceMethods - attr_accessor :options, :read_results + attr_accessor :options, :read_results, :write_results def initialize(options = {}) @options = options @read_results = [] + @write_results = [] @read_index = 0 + @write_index = 0 end def read_nonblock(_count, exception: true) @@ -16,6 +18,12 @@ def read_nonblock(_count, exception: true) @read_index += 1 result end + + def write_nonblock(_bytes, exception: true) + result = @write_results[@write_index] + @write_index += 1 + result + end end describe 'Dalli::Socket::InstanceMethods' do @@ -81,6 +89,70 @@ def read_nonblock(_count, exception: true) end end + describe '#writefull' do + it 'writes all bytes in a single call' do + sock.write_results = [5] + assert_equal 5, sock.writefull("hello") + end + + it 'handles partial writes across multiple calls' do + sock.write_results = [2, 3] + assert_equal 5, sock.writefull("hello") + end + + it 'retries on :wait_writable when IO.select succeeds' do + sock.write_results = [:wait_writable, 5] + IO.stubs(:select).with(nil, [sock], nil, 1).returns([nil, [sock]]) + assert_equal 5, sock.writefull("hello") + end + + it 'retries on :wait_readable when IO.select succeeds' do + sock.write_results = [:wait_readable, 5] + IO.stubs(:select).with([sock], nil, nil, 1).returns([[sock]]) + assert_equal 5, sock.writefull("hello") + end + + it 'raises Timeout::Error on :wait_writable when IO.select times out' do + sock.write_results = [:wait_writable] + IO.stubs(:select).with(nil, [sock], nil, 1).returns(nil) + assert_raises(Timeout::Error) { sock.writefull("hello") } + end + + it 'raises Timeout::Error on :wait_readable when IO.select times out' do + sock.write_results = [:wait_readable] + IO.stubs(:select).with([sock], nil, nil, 1).returns(nil) + assert_raises(Timeout::Error) { sock.writefull("hello") } + end + + it 'delivers all bytes through a real socket pair' do + s1, s2 = Socket.pair(:UNIX, :STREAM, 0) + s1.extend(Dalli::Socket::InstanceMethods) + def s1.options; { socket_timeout: 5 }; end + + data = "hello world" + result = s1.writefull(data) + s1.close + + assert_equal data.bytesize, result + assert_equal data, s2.read + ensure + s1&.close rescue nil + s2&.close rescue nil + end + + describe 'with credentials' do + let(:sock) { MockSocket.new(socket_timeout: 1, username: 'admin', password: 'secret') } + + it 'excludes credentials from Timeout::Error message' do + sock.write_results = [:wait_writable] + IO.stubs(:select).with(nil, [sock], nil, 1).returns(nil) + error = assert_raises(Timeout::Error) { sock.writefull("hello") } + refute_match(/admin/, error.message) + refute_match(/secret/, error.message) + end + end + end + describe '#read_available' do it 'reads all available data until :wait_readable' do sock.read_results = ["he", "llo", :wait_readable] @@ -179,6 +251,19 @@ def read_nonblock(_count, exception: true) @sock = Dalli::Socket::TCP.open('127.0.0.1', @port, 'my_server', socket_timeout: 5) assert_equal 'my_server', @sock.server end + + it 'raises SocketError for unresolvable hostname' do + assert_raises(SocketError) do + Dalli::Socket::TCP.open('this-host-does-not-exist.invalid', 11211, 'srv', socket_timeout: 1) + end + end + + it 'includes hostname in SocketError message for unresolvable host' do + error = assert_raises(SocketError) do + Dalli::Socket::TCP.open('this-host-does-not-exist.invalid', 11211, 'srv', socket_timeout: 1) + end + assert_match(/this-host-does-not-exist\.invalid/, error.message) + end end describe 'Dalli::Socket::UNIX' do @@ -212,4 +297,10 @@ def read_nonblock(_count, exception: true) @sock = Dalli::Socket::UNIX.open(@path, 'my_server', socket_timeout: 5) assert_equal 'my_server', @sock.server end + + it 'raises Errno::ENOENT for non-existent socket path' do + assert_raises(Errno::ENOENT) do + Dalli::Socket::UNIX.open('/tmp/nonexistent_dalli_test_socket', 'srv', socket_timeout: 1) + end + end end