Skip to content

Commit 30c7579

Browse files
authored
listener: add generic socket interface selection and filter cleanup (#40304)
This PR introduces a mechanism for custom address types to specify their preferred socket interface. It also adds cleanup capabilities for listener filters. We need this change for custom connection types (like reverse tunnels) to integrate cleanly without requiring any hardcoded logic in any of the core components. --- **Risk Level:** Low **Testing:** Added Unit Tests **Docs Changes:** N/A **Release Notes:** N/A Signed-off-by: Rohit Agrawal <[email protected]>
1 parent c34a837 commit 30c7579

File tree

6 files changed

+250
-0
lines changed

6 files changed

+250
-0
lines changed

envoy/network/filter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,12 @@ class ListenerFilter {
440440
*/
441441
virtual FilterStatus onData(Network::ListenerFilterBuffer& buffer) PURE;
442442

443+
/**
444+
* Called when the connection is closed. Only the current filter that has stopped filter
445+
* chain iteration will get the callback.
446+
*/
447+
virtual void onClose() {};
448+
443449
/**
444450
* Return the size of data the filter want to inspect from the connection.
445451
* The size can be increased after filter need to inspect more data.

source/common/listener_manager/active_tcp_socket.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ void ActiveTcpSocket::createListenerFilterBuffer() {
7474
listener_filter_buffer_ = std::make_unique<Network::ListenerFilterBufferImpl>(
7575
socket_->ioHandle(), listener_.dispatcher(),
7676
[this](bool error) {
77+
(*iter_)->onClose();
7778
socket_->ioHandle().close();
7879
if (error) {
7980
listener_.stats_.downstream_listener_filter_error_.inc();

source/common/listener_manager/active_tcp_socket.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class ActiveTcpSocket : public Network::ListenerFilterManager,
5353
}
5454

5555
size_t maxReadBytes() const override { return listener_filter_->maxReadBytes(); }
56+
57+
void onClose() override { return listener_filter_->onClose(); }
5658
};
5759
using ListenerFilterWrapperPtr = std::unique_ptr<GenericListenerFilter>;
5860

source/common/listener_manager/listener_manager_impl.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "source/common/network/filter_matcher.h"
2222
#include "source/common/network/io_socket_handle_impl.h"
2323
#include "source/common/network/listen_socket_impl.h"
24+
#include "source/common/network/socket_interface.h"
2425
#include "source/common/network/socket_option_factory.h"
2526
#include "source/common/network/utility.h"
2627
#include "source/common/protobuf/utility.h"
@@ -316,6 +317,22 @@ absl::StatusOr<Network::SocketSharedPtr> ProdListenerComponentFactory::createLis
316317
ASSERT(socket_type == Network::Socket::Type::Stream ||
317318
socket_type == Network::Socket::Type::Datagram);
318319

320+
// Use the address's socket interface for socket creation.
321+
const Network::SocketInterface& socket_interface = address->socketInterface();
322+
const Network::SocketInterface& default_interface = Network::SocketInterfaceSingleton::get();
323+
324+
// Check if this address specifies a custom socket interface.
325+
if (&socket_interface != &default_interface) {
326+
ENVOY_LOG(debug, "creating socket using custom interface for address: {}",
327+
address->logicalName());
328+
auto io_handle = socket_interface.socket(socket_type, address, creation_options);
329+
if (!io_handle) {
330+
return absl::InvalidArgumentError("failed to create socket using custom interface");
331+
}
332+
return std::make_shared<Network::TcpListenSocket>(std::move(io_handle), address, options);
333+
}
334+
335+
// Continue with standard socket creation for addresses using the default interface.
319336
// First we try to get the socket from our parent if applicable in each case below.
320337
if (address->type() == Network::Address::Type::Pipe) {
321338
if (socket_type != Network::Socket::Type::Stream) {
@@ -401,6 +418,7 @@ ListenerManagerImpl::ListenerManagerImpl(Instance& server,
401418
for (uint32_t i = 0; i < server.options().concurrency(); i++) {
402419
workers_.emplace_back(worker_factory.createWorker(
403420
i, server.overloadManager(), server.nullOverloadManager(), absl::StrCat("worker_", i)));
421+
ENVOY_LOG(debug, "starting worker: {}", i);
404422
}
405423
}
406424

test/common/listener_manager/listener_manager_impl_test.cc

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8242,6 +8242,228 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, EmptyConnectionBalanceConfig) {
82428242
#endif
82438243
}
82448244

8245+
// Test mock socket interface for custom address testing.
8246+
class TestCustomSocketInterface : public Network::SocketInterfaceBase {
8247+
public:
8248+
TestCustomSocketInterface() = default;
8249+
8250+
// Network::SocketInterface
8251+
Network::IoHandlePtr socket(Network::Socket::Type socket_type, Network::Address::Type addr_type,
8252+
Network::Address::IpVersion version, bool socket_v6only,
8253+
const Network::SocketCreationOptions& options) const override {
8254+
UNREFERENCED_PARAMETER(socket_v6only);
8255+
UNREFERENCED_PARAMETER(options);
8256+
// Create a regular socket for testing
8257+
if (socket_type == Network::Socket::Type::Stream && addr_type == Network::Address::Type::Ip) {
8258+
int domain = (version == Network::Address::IpVersion::v4) ? AF_INET : AF_INET6;
8259+
int sock_fd = ::socket(domain, SOCK_STREAM, 0);
8260+
if (sock_fd == -1) {
8261+
return nullptr;
8262+
}
8263+
was_called_ = true;
8264+
return std::make_unique<Network::IoSocketHandleImpl>(sock_fd);
8265+
}
8266+
return nullptr;
8267+
}
8268+
8269+
Network::IoHandlePtr socket(Network::Socket::Type socket_type,
8270+
const Network::Address::InstanceConstSharedPtr addr,
8271+
const Network::SocketCreationOptions& options) const override {
8272+
// Delegate to the other socket method
8273+
return socket(socket_type, addr->type(),
8274+
addr->ip() ? addr->ip()->version() : Network::Address::IpVersion::v4, false,
8275+
options);
8276+
}
8277+
8278+
bool ipFamilySupported(int domain) override { return domain == AF_INET || domain == AF_INET6; }
8279+
8280+
// Server::Configuration::BootstrapExtensionFactory
8281+
Server::BootstrapExtensionPtr
8282+
createBootstrapExtension(const Protobuf::Message& config,
8283+
Server::Configuration::ServerFactoryContext& context) override {
8284+
UNREFERENCED_PARAMETER(config);
8285+
UNREFERENCED_PARAMETER(context);
8286+
return nullptr; // Not used in test
8287+
}
8288+
8289+
ProtobufTypes::MessagePtr createEmptyConfigProto() override {
8290+
return nullptr; // Not used in test
8291+
}
8292+
8293+
std::string name() const override { return "test.custom.socket.interface"; }
8294+
8295+
// Test helper
8296+
bool wasCalled() const { return was_called_; }
8297+
void resetCalled() { was_called_ = false; }
8298+
8299+
private:
8300+
mutable bool was_called_{false};
8301+
};
8302+
8303+
// Test address that returns a custom socket interface
8304+
class TestCustomAddress : public Network::Address::Instance {
8305+
public:
8306+
TestCustomAddress(const Network::SocketInterface& custom_interface)
8307+
: address_string_("127.0.0.1:0"), logical_name_("custom://test-address"),
8308+
custom_interface_(custom_interface),
8309+
ipv4_instance_(std::make_shared<Network::Address::Ipv4Instance>("127.0.0.1", 0)) {}
8310+
8311+
// Network::Address::Instance
8312+
bool operator==(const Instance& rhs) const override { return address_string_ == rhs.asString(); }
8313+
Network::Address::Type type() const override { return Network::Address::Type::Ip; }
8314+
const std::string& asString() const override { return address_string_; }
8315+
absl::string_view asStringView() const override { return address_string_; }
8316+
const std::string& logicalName() const override { return logical_name_; }
8317+
const Network::Address::Ip* ip() const override { return ipv4_instance_->ip(); }
8318+
const Network::Address::Pipe* pipe() const override { return nullptr; }
8319+
const Network::Address::EnvoyInternalAddress* envoyInternalAddress() const override {
8320+
return nullptr;
8321+
}
8322+
absl::optional<std::string> networkNamespace() const override { return absl::nullopt; }
8323+
const sockaddr* sockAddr() const override { return ipv4_instance_->sockAddr(); }
8324+
socklen_t sockAddrLen() const override { return ipv4_instance_->sockAddrLen(); }
8325+
absl::string_view addressType() const override { return "test_custom"; }
8326+
8327+
// Return the custom socket interface
8328+
const Network::SocketInterface& socketInterface() const override { return custom_interface_; }
8329+
8330+
private:
8331+
std::string address_string_;
8332+
std::string logical_name_;
8333+
const Network::SocketInterface& custom_interface_;
8334+
Network::Address::InstanceConstSharedPtr ipv4_instance_;
8335+
};
8336+
8337+
// Test address that returns the default socket interface
8338+
class TestDefaultAddress : public Network::Address::Instance {
8339+
public:
8340+
TestDefaultAddress()
8341+
: address_string_("127.0.0.1:0"), logical_name_("default://test-address"),
8342+
ipv4_instance_(std::make_shared<Network::Address::Ipv4Instance>("127.0.0.1", 0)) {}
8343+
8344+
// Network::Address::Instance
8345+
bool operator==(const Instance& rhs) const override { return address_string_ == rhs.asString(); }
8346+
Network::Address::Type type() const override { return Network::Address::Type::Ip; }
8347+
const std::string& asString() const override { return address_string_; }
8348+
absl::string_view asStringView() const override { return address_string_; }
8349+
const std::string& logicalName() const override { return logical_name_; }
8350+
const Network::Address::Ip* ip() const override { return ipv4_instance_->ip(); }
8351+
const Network::Address::Pipe* pipe() const override { return nullptr; }
8352+
const Network::Address::EnvoyInternalAddress* envoyInternalAddress() const override {
8353+
return nullptr;
8354+
}
8355+
absl::optional<std::string> networkNamespace() const override { return absl::nullopt; }
8356+
const sockaddr* sockAddr() const override { return ipv4_instance_->sockAddr(); }
8357+
socklen_t sockAddrLen() const override { return ipv4_instance_->sockAddrLen(); }
8358+
absl::string_view addressType() const override { return "test_default"; }
8359+
8360+
// Return the default socket interface
8361+
const Network::SocketInterface& socketInterface() const override {
8362+
return Network::SocketInterfaceSingleton::get();
8363+
}
8364+
8365+
private:
8366+
std::string address_string_;
8367+
std::string logical_name_;
8368+
Network::Address::InstanceConstSharedPtr ipv4_instance_;
8369+
};
8370+
8371+
TEST_P(ListenerManagerImplTest, CustomSocketInterfaceIsUsedWhenAddressSpecifiesIt) {
8372+
auto custom_interface = std::make_unique<TestCustomSocketInterface>();
8373+
TestCustomSocketInterface* custom_interface_ptr = custom_interface.get();
8374+
8375+
auto custom_address = std::make_shared<TestCustomAddress>(*custom_interface);
8376+
8377+
// Create listener factory to test the implementation
8378+
ProdListenerComponentFactory real_listener_factory(server_);
8379+
8380+
Network::Socket::OptionsSharedPtr options = nullptr;
8381+
Network::SocketCreationOptions creation_options;
8382+
8383+
// Verify that the custom address returns the custom interface
8384+
EXPECT_NE(&custom_address->socketInterface(), &Network::SocketInterfaceSingleton::get());
8385+
8386+
// The listener factory should use the custom socket interface
8387+
auto socket_result = real_listener_factory.createListenSocket(
8388+
custom_address, Network::Socket::Type::Stream, options,
8389+
ListenerComponentFactory::BindType::NoBind, creation_options, 0 /* worker_index */);
8390+
8391+
// The socket creation should succeed
8392+
EXPECT_TRUE(socket_result.ok());
8393+
if (socket_result.ok()) {
8394+
auto socket = socket_result.value();
8395+
EXPECT_NE(socket, nullptr);
8396+
// Verify the socket was created with the expected address
8397+
EXPECT_EQ(socket->connectionInfoProvider().localAddress()->logicalName(),
8398+
custom_address->logicalName());
8399+
}
8400+
8401+
// Verify the custom interface was actually called
8402+
EXPECT_TRUE(custom_interface_ptr->wasCalled());
8403+
}
8404+
8405+
TEST_P(ListenerManagerImplTest, DefaultSocketInterfaceIsUsedWhenAddressUsesDefault) {
8406+
auto default_address = std::make_shared<TestDefaultAddress>();
8407+
8408+
// Create listener factory to test the implementation
8409+
ProdListenerComponentFactory real_listener_factory(server_);
8410+
8411+
Network::Socket::OptionsSharedPtr options = nullptr;
8412+
Network::SocketCreationOptions creation_options;
8413+
8414+
// Verify that the default address returns the default interface
8415+
EXPECT_EQ(&default_address->socketInterface(), &Network::SocketInterfaceSingleton::get());
8416+
8417+
// The listener factory should use the standard socket creation path
8418+
auto socket_result = real_listener_factory.createListenSocket(
8419+
default_address, Network::Socket::Type::Stream, options,
8420+
ListenerComponentFactory::BindType::NoBind, creation_options, 0 /* worker_index */);
8421+
8422+
// The socket creation should succeed
8423+
EXPECT_TRUE(socket_result.ok());
8424+
if (socket_result.ok()) {
8425+
auto socket = socket_result.value();
8426+
EXPECT_NE(socket, nullptr);
8427+
// Verify the socket was created with the expected address
8428+
EXPECT_EQ(socket->connectionInfoProvider().localAddress()->logicalName(),
8429+
default_address->logicalName());
8430+
}
8431+
}
8432+
8433+
TEST_P(ListenerManagerImplTest, CustomSocketInterfaceFailureIsHandledGracefully) {
8434+
// Create a failing custom socket interface
8435+
class FailingCustomSocketInterface : public TestCustomSocketInterface {
8436+
public:
8437+
Network::IoHandlePtr socket(Network::Socket::Type socket_type,
8438+
const Network::Address::InstanceConstSharedPtr addr,
8439+
const Network::SocketCreationOptions& options) const override {
8440+
UNREFERENCED_PARAMETER(socket_type);
8441+
UNREFERENCED_PARAMETER(addr);
8442+
UNREFERENCED_PARAMETER(options);
8443+
// Always return nullptr to simulate failure
8444+
return nullptr;
8445+
}
8446+
};
8447+
8448+
auto failing_interface = std::make_unique<FailingCustomSocketInterface>();
8449+
auto custom_address = std::make_shared<TestCustomAddress>(*failing_interface);
8450+
8451+
// Create listener factory to test the implementation
8452+
ProdListenerComponentFactory real_listener_factory(server_);
8453+
8454+
Network::Socket::OptionsSharedPtr options = nullptr;
8455+
Network::SocketCreationOptions creation_options;
8456+
8457+
// The listener factory should handle the failure gracefully
8458+
auto socket_result = real_listener_factory.createListenSocket(
8459+
custom_address, Network::Socket::Type::Stream, options,
8460+
ListenerComponentFactory::BindType::NoBind, creation_options, 0 /* worker_index */);
8461+
8462+
// The socket creation should fail with the expected error
8463+
EXPECT_FALSE(socket_result.ok());
8464+
EXPECT_EQ(socket_result.status().message(), "failed to create socket using custom interface");
8465+
}
8466+
82458467
INSTANTIATE_TEST_SUITE_P(Matcher, ListenerManagerImplTest, ::testing::Values(false));
82468468
INSTANTIATE_TEST_SUITE_P(Matcher, ListenerManagerImplWithRealFiltersTest,
82478469
::testing::Values(false, true));

test/mocks/network/mocks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ class MockListenerFilter : public ListenerFilter {
273273
MOCK_METHOD(void, destroy_, ());
274274
MOCK_METHOD(Network::FilterStatus, onAccept, (ListenerFilterCallbacks&));
275275
MOCK_METHOD(Network::FilterStatus, onData, (Network::ListenerFilterBuffer&));
276+
MOCK_METHOD(void, onClose, ());
276277

277278
size_t listener_filter_max_read_bytes_{0};
278279
};

0 commit comments

Comments
 (0)