Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/dalli/server.rb
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def flush_response
def write(bytes)
begin
@inprogress = true
result = @sock.write(bytes)
result = @sock.writefull(bytes)
@inprogress = false
result
rescue SystemCallError, Timeout::Error => e
Expand Down
112 changes: 96 additions & 16 deletions lib/dalli/socket.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# frozen_string_literal: true

require 'resolv'

module Dalli
module Socket
module InstanceMethods
Expand All @@ -21,6 +23,22 @@ def readfull(count)
value
end

# Writes the whole of +bytes+ to the socket using non-blocking writes,
# looping until every byte has been handed off.
#
# When a write cannot make progress, waits with IO.select for up to
# options[:socket_timeout] seconds per wait (for writability on
# :wait_writable, for readability on :wait_readable) and then retries.
#
# @param bytes [String] payload to send
# @return [Integer] total number of bytes written
# @raise [Timeout::Error] if a wait elapses without the socket becoming ready
def writefull(bytes)
  written = 0
  until written >= bytes.bytesize
    # First pass sends the caller's string as-is; later passes send only
    # the tail that has not been written yet.
    pending = written.zero? ? bytes : bytes.byteslice(written..-1)
    outcome = write_nonblock(pending, exception: false)

    case outcome
    when :wait_writable
      ready = IO.select(nil, [self], nil, options[:socket_timeout])
      raise Timeout::Error, "IO timeout: #{safe_options.inspect}" unless ready
    when :wait_readable
      ready = IO.select([self], nil, nil, options[:socket_timeout])
      raise Timeout::Error, "IO timeout: #{safe_options.inspect}" unless ready
    else
      # Anything else is the byte count reported by this write.
      written += outcome
    end
  end
  written
end

def read_available
value = +""
loop do
Expand All @@ -43,35 +61,97 @@ def safe_options
end
end

class TCP < TCPSocket
class TCP < ::Socket
include Dalli::Socket::InstanceMethods
attr_accessor :options, :server

# Opens a TCP connection to host:port and returns the socket with its
# +options+ and +server+ attributes populated.
#
# NOTE(review): this span is taken from a diff view and appears to interleave
# the removed implementation (the Timeout.timeout block directly below, which
# has no matching `end` here) with its replacement (everything from the
# resolve_address call onward) — confirm which lines are live.
def self.open(host, port, server, options = {})
Timeout.timeout(options[:socket_timeout]) do
sock = new(host, port)
sock.options = {host: host, port: port}.merge(options)
sock.server = server
sock.setsockopt(::Socket::IPPROTO_TCP, ::Socket::TCP_NODELAY, true)
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_KEEPALIVE, true) if options[:keepalive]
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_RCVBUF, options[:rcvbuf]) if options[:rcvbuf]
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_SNDBUF, options[:sndbuf]) if options[:sndbuf]
sock
# Resolve the host first, handing the socket timeout to the resolver.
addr_info = resolve_address(host, options[:socket_timeout])
sock = new(addr_info[4], ::Socket::SOCK_STREAM, 0) # addr_info[4] == address family constant (e.g. AF_INET), expressed as an integer

# TCP_NODELAY is always set; keepalive and buffer sizes are applied only
# when the matching option is present.
sock.setsockopt(::Socket::IPPROTO_TCP, ::Socket::TCP_NODELAY, true)
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_KEEPALIVE, true) if options[:keepalive]
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_RCVBUF, options[:rcvbuf]) if options[:rcvbuf]
sock.setsockopt(::Socket::SOL_SOCKET, ::Socket::SO_SNDBUF, options[:sndbuf]) if options[:sndbuf]

sockaddr = ::Socket.pack_sockaddr_in(port, addr_info[3]) # addr_info[3] == IP address string (e.g. "192.168.1.1")
# Start a non-blocking connect; :wait_writable means the attempt is still
# in progress.
result = sock.connect_nonblock(sockaddr, exception: false)
if result == :wait_writable
# Wait up to socket_timeout for the socket to become writable.
unless IO.select(nil, [sock], nil, options[:socket_timeout])
raise Timeout::Error, "Connection timeout: #{host}:#{port}"
end
# Ask again to learn the outcome: Errno::EISCONN is treated as success,
# any other error falls through to the rescue at the bottom.
begin
sock.connect_nonblock(sockaddr)
rescue Errno::EISCONN
# already connected
end
end

sock.options = { host: host, port: port }.merge(options)
sock.server = server
sock
# On failure, close the socket on a best-effort basis and re-raise the
# original error.
rescue
sock&.close rescue nil
raise
end

# Resolve a hostname to structured address info with timeout protection.
# getaddrinfo(3) is a blocking C library call that can block indefinitely
# on unresponsive DNS. For IP addresses (the common case with memcached),
# getaddrinfo returns immediately without DNS and is safe to call directly.
# For hostnames, we use Ruby's Resolv library which is pure Ruby and
# supports timeouts, then pass the resolved IP to getaddrinfo for the
# structured address info the caller expects.
def self.resolve_address(host, timeout)
if ip_address?(host)
return ::Socket.getaddrinfo(host, nil, ::Socket::AF_UNSPEC, ::Socket::SOCK_STREAM).first
end

dns = Resolv::DNS.new
dns.timeouts = timeout
resolver = Resolv.new([Resolv::Hosts.new, dns])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how costly it is, but I generally like the idea of caching resolved hosts for a while instead of resolving on every connection.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At Braze, we shouldn't ever get down to this line because we use IP addresses to connect to memcached and will return on the guard clause. Also connection creation is not a hot path, so optimizing here is not so important.

Also also, this bit from Claude: if we were to use DNS, there is a risk of caching and using stale hosts if K8s should rotate DNS records to redirect traffic. Leaving as-is for now

resolved_ip = resolver.getaddress(host).to_s
::Socket.getaddrinfo(resolved_ip, nil, ::Socket::AF_UNSPEC, ::Socket::SOCK_STREAM).first
rescue Resolv::ResolvError => e
raise SocketError, "getaddrinfo: Name or service not known - #{host} (#{e.message})"
ensure
dns&.close
end
private_class_method :resolve_address

# Returns true if host is an IP address (v4 or v6) rather than a hostname.
def self.ip_address?(host)
host.match?(/\A\d{1,3}(\.\d{1,3}){3}\z/) || host.include?(':')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolv::IPv4 and Resolv::IPv6 both define a Regex constant; might they be useful here?

  host.match?(Resolv::IPv4::Regex) || host.match?(Resolv::IPv6::Regex)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea
host.include?(':') already captures IPv6 format, but is looser and will capture edge cases that Resolv::IPv6::Regex will miss (e.g. a scoped address like fe80::1%eth0)
Gonna move forward with

host.match?(Resolv::IPv4::Regex) || host.include?(':')

end
private_class_method :ip_address?
end

class UNIX < UNIXSocket
class UNIX < ::Socket
include Dalli::Socket::InstanceMethods
attr_accessor :options, :server

def self.open(path, server, options = {})
Timeout.timeout(options[:socket_timeout]) do
sock = new(path)
sock.options = {path: path}.merge(options)
sock.server = server
sock
sock = new(::Socket::AF_UNIX, ::Socket::SOCK_STREAM, 0)
sockaddr = ::Socket.pack_sockaddr_un(path)

result = sock.connect_nonblock(sockaddr, exception: false)
if result == :wait_writable
unless IO.select(nil, [sock], nil, options[:socket_timeout])
raise Timeout::Error, "Connection timeout: #{path}"
end
begin
sock.connect_nonblock(sockaddr)
rescue Errno::EISCONN
# already connected
end
end

sock.options = { path: path }.merge(options)
sock.server = server
sock
rescue

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use `rescue => e` so that we at least only match StandardError? A bare `rescue` sounds too catch-all-ish.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, a bare `rescue` and `rescue => e` catch the same thing (StandardError and its subclasses); the latter just additionally binds the exception to the variable `e`.

sock&.close rescue nil
raise
end
end
end
Expand Down
93 changes: 92 additions & 1 deletion test/test_socket.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@

# Test double that mixes in Dalli::Socket::InstanceMethods and replays
# scripted results for non-blocking reads and writes, one per call.
class MockSocket
include Dalli::Socket::InstanceMethods
# NOTE(review): two attr_accessor lines appear here, presumably the old and
# new declarations shown side by side by the diff view — confirm only the
# second (with :write_results) is live.
attr_accessor :options, :read_results
attr_accessor :options, :read_results, :write_results

# Stores the given options and starts with empty read/write scripts and
# both cursors at zero.
def initialize(options = {})
@options = options
@read_results = []
@write_results = []
@read_index = 0
@write_index = 0
end

# Returns the next scripted read result and advances the read cursor;
# the requested count and the exception flag are ignored.
def read_nonblock(_count, exception: true)
result = @read_results[@read_index]
@read_index += 1
result
end

# Returns the next scripted write result and advances the write cursor;
# the bytes passed in and the exception flag are ignored.
def write_nonblock(_bytes, exception: true)
result = @write_results[@write_index]
@write_index += 1
result
end
end

describe 'Dalli::Socket::InstanceMethods' do
Expand Down Expand Up @@ -81,6 +89,70 @@ def read_nonblock(_count, exception: true)
end
end

# Specs for #writefull: full and partial writes, retry after a successful
# IO.select wait, timeout when the wait fails, a real socket pair round trip,
# and credential redaction in the timeout message.
describe '#writefull' do
  it 'writes all bytes in a single call' do
    sock.write_results = [5]
    written = sock.writefull('hello')
    assert_equal 5, written
  end

  it 'handles partial writes across multiple calls' do
    sock.write_results = [2, 3]
    written = sock.writefull('hello')
    assert_equal 5, written
  end

  it 'retries on :wait_writable when IO.select succeeds' do
    sock.write_results = [:wait_writable, 5]
    IO.stubs(:select).with(nil, [sock], nil, 1).returns([nil, [sock]])
    written = sock.writefull('hello')
    assert_equal 5, written
  end

  it 'retries on :wait_readable when IO.select succeeds' do
    sock.write_results = [:wait_readable, 5]
    IO.stubs(:select).with([sock], nil, nil, 1).returns([[sock]])
    written = sock.writefull('hello')
    assert_equal 5, written
  end

  it 'raises Timeout::Error on :wait_writable when IO.select times out' do
    sock.write_results = [:wait_writable]
    IO.stubs(:select).with(nil, [sock], nil, 1).returns(nil)
    assert_raises(Timeout::Error) do
      sock.writefull('hello')
    end
  end

  it 'raises Timeout::Error on :wait_readable when IO.select times out' do
    sock.write_results = [:wait_readable]
    IO.stubs(:select).with([sock], nil, nil, 1).returns(nil)
    assert_raises(Timeout::Error) do
      sock.writefull('hello')
    end
  end

  it 'delivers all bytes through a real socket pair' do
    writer, reader = Socket.pair(:UNIX, :STREAM, 0)
    writer.extend(Dalli::Socket::InstanceMethods)
    # The mixin reads options[:socket_timeout]; supply it on this instance only.
    writer.define_singleton_method(:options) { { socket_timeout: 5 } }

    payload = 'hello world'
    sent = writer.writefull(payload)
    # Close the writing end so the reader's #read sees EOF.
    writer.close

    assert_equal payload.bytesize, sent
    assert_equal payload, reader.read
  ensure
    writer&.close rescue nil
    reader&.close rescue nil
  end

  describe 'with credentials' do
    let(:sock) { MockSocket.new(socket_timeout: 1, username: 'admin', password: 'secret') }

    it 'excludes credentials from Timeout::Error message' do
      sock.write_results = [:wait_writable]
      IO.stubs(:select).with(nil, [sock], nil, 1).returns(nil)
      raised = assert_raises(Timeout::Error) do
        sock.writefull('hello')
      end
      [/admin/, /secret/].each do |leak|
        refute_match(leak, raised.message)
      end
    end
  end
end

describe '#read_available' do
it 'reads all available data until :wait_readable' do
sock.read_results = ["he", "llo", :wait_readable]
Expand Down Expand Up @@ -179,6 +251,19 @@ def read_nonblock(_count, exception: true)
@sock = Dalli::Socket::TCP.open('127.0.0.1', @port, 'my_server', socket_timeout: 5)
assert_equal 'my_server', @sock.server
end

# A hostname that cannot be resolved must surface as SocketError.
it 'raises SocketError for unresolvable hostname' do
  bogus_host = 'this-host-does-not-exist.invalid'
  assert_raises(SocketError) do
    Dalli::Socket::TCP.open(bogus_host, 11211, 'srv', socket_timeout: 1)
  end
end

# The error message must name the host that failed to resolve.
it 'includes hostname in SocketError message for unresolvable host' do
  bogus_host = 'this-host-does-not-exist.invalid'
  raised = assert_raises(SocketError) do
    Dalli::Socket::TCP.open(bogus_host, 11211, 'srv', socket_timeout: 1)
  end
  assert_match(/this-host-does-not-exist\.invalid/, raised.message)
end
end

describe 'Dalli::Socket::UNIX' do
Expand Down Expand Up @@ -212,4 +297,10 @@ def read_nonblock(_count, exception: true)
@sock = Dalli::Socket::UNIX.open(@path, 'my_server', socket_timeout: 5)
assert_equal 'my_server', @sock.server
end

# Connecting to a socket path that does not exist must raise Errno::ENOENT.
it 'raises Errno::ENOENT for non-existent socket path' do
  missing_path = '/tmp/nonexistent_dalli_test_socket'
  assert_raises(Errno::ENOENT) do
    Dalli::Socket::UNIX.open(missing_path, 'srv', socket_timeout: 1)
  end
end
end