diff --git a/src/brpc/channel.cpp b/src/brpc/channel.cpp index 0252e97d74..7e6381c7d7 100644 --- a/src/brpc/channel.cpp +++ b/src/brpc/channel.cpp @@ -77,6 +77,8 @@ ChannelSSLOptions* ChannelOptions::mutable_ssl_options() { static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { if (opt.auth == NULL && !opt.has_ssl_options() && + opt.client_host.empty() && + opt.device_name.empty() && opt.connection_group.empty() && opt.hc_option.health_check_path.empty()) { // Returning zeroized result by default is more intuitive for users. @@ -94,6 +96,14 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { buf.append("|conng="); buf.append(opt.connection_group); } + if (!opt.client_host.empty()) { + buf.append("|clih="); + buf.append(opt.client_host); + } + if (!opt.device_name.empty()) { + buf.append("|devn="); + buf.append(opt.device_name); + } if (opt.auth) { buf.append("|auth="); buf.append((char*)&opt.auth, sizeof(opt.auth)); @@ -362,6 +372,13 @@ int Channel::InitSingle(const butil::EndPoint& server_addr_and_port, LOG(ERROR) << "Invalid port=" << port; return -1; } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } _server_address = server_addr_and_port; const ChannelSignature sig = ComputeChannelSignature(_options); std::shared_ptr ssl_ctx; @@ -369,7 +386,9 @@ int Channel::InitSingle(const butil::EndPoint& server_addr_and_port, return -1; } if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig), - &_server_id, ssl_ctx, _options.use_rdma, _options.hc_option) != 0) { + &_server_id, ssl_ctx, _options.use_rdma, + _options.hc_option, client_endpoint, + _options.device_name) != 0) { LOG(ERROR) << "Fail to insert into SocketMap"; return -1; } @@ -397,6 +416,13 @@ int Channel::Init(const char* ns_url, _options.mutable_ssl_options()->sni_name = _service_name; } } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } std::unique_ptr lb(new (std::nothrow) LoadBalancerWithNaming); if (NULL == lb) { @@ -409,6 +435,8 @@ int Channel::Init(const char* ns_url, ns_opt.use_rdma = _options.use_rdma; ns_opt.channel_signature = ComputeChannelSignature(_options); ns_opt.hc_option = _options.hc_option; + ns_opt.client_endpoint = client_endpoint; + ns_opt.device_name = _options.device_name; if (CreateSocketSSLContext(_options, &ns_opt.ssl_ctx) != 0) { return -1; } diff --git a/src/brpc/channel.h b/src/brpc/channel.h index c970209b3a..0f349ac6fe 100644 --- a/src/brpc/channel.h +++ b/src/brpc/channel.h @@ -148,6 +148,16 @@ struct ChannelOptions { // Its priority is higher than FLAGS_health_check_path and FLAGS_health_check_timeout_ms. // When it is not set, FLAGS_health_check_path and FLAGS_health_check_timeout_ms will take effect. HealthCheckOption hc_option; + + // IP address or host name of the client. + // if the client_host is "", the client IP address is determined by the OS. + // Default: "" + std::string client_host; + + // The device name of the client's network adapter. + // if the device_name is "", the flow control is determined by the OS. + // Default: "" + std::string device_name; private: // SSLOptions is large and not often used, allocate it on heap to // prevent ChannelOptions from being bloated in most cases. diff --git a/src/brpc/details/naming_service_thread.cpp b/src/brpc/details/naming_service_thread.cpp index 341ca35b09..a75db265b0 100644 --- a/src/brpc/details/naming_service_thread.cpp +++ b/src/brpc/details/naming_service_thread.cpp @@ -126,7 +126,9 @@ void NamingServiceThread::Actions::ResetServers( // to pick those Sockets with the right settings during OnAddedServers const SocketMapKey key(_added[i], _owner->_options.channel_signature); CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx, - _owner->_options.use_rdma, _owner->_options.hc_option)); + _owner->_options.use_rdma, _owner->_options.hc_option, + _owner->_options.client_endpoint, + _owner->_options.device_name)); _added_sockets.push_back(tagged_id); } diff --git a/src/brpc/details/naming_service_thread.h b/src/brpc/details/naming_service_thread.h index 1745e5f267..34a29a8622 100644 --- a/src/brpc/details/naming_service_thread.h +++ b/src/brpc/details/naming_service_thread.h @@ -53,6 +53,8 @@ struct GetNamingServiceThreadOptions { HealthCheckOption hc_option; ChannelSignature channel_signature; std::shared_ptr ssl_ctx; + butil::EndPoint client_endpoint; + std::string device_name; }; // A dedicated thread to map a name to ServerIds diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 9490650b78..e431aceff9 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -728,7 +728,8 @@ int Socket::OnCreated(const SocketOptions& options) { _keytable_pool = options.keytable_pool; _tos = 0; _remote_side = options.remote_side; - _local_side = butil::EndPoint(); + _local_side = options.local_side; + _device_name = options.device_name; _on_edge_triggered_events = options.on_edge_triggered_events; _user = options.user; _conn = options.conn; @@ -1296,7 +1297,25 @@ int Socket::Connect(const timespec* abstime, CHECK_EQ(0, butil::make_close_on_exec(sockfd)); // We need to do async connect (to manage the timeout by ourselves). CHECK_EQ(0, butil::make_non_blocking(sockfd)); - + if (!_device_name.empty()) { + if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, + _device_name.c_str(), _device_name.size()) < 0) { + PLOG(ERROR) << "Fail to set SO_BINDTODEVICE of fd=" << sockfd + << " to device_name=" << _device_name; + return -1; + } + } + if (local_side().ip != butil::IP_ANY) { + struct sockaddr_storage cli_addr; + if (butil::endpoint2sockaddr(local_side(), &cli_addr, &addr_size) != 0) { + PLOG(ERROR) << "Fail to get client sockaddr"; + return -1; + } + if (::bind(sockfd, (struct sockaddr*)&cli_addr, addr_size) != 0) { + PLOG(ERROR) << "Fail to bind client socket, errno=" << strerror(errno); + return -1; + } + } const int rc = ::connect( sockfd, (struct sockaddr*)&serv_addr, addr_size); if (rc != 0 && errno != EINPROGRESS) { @@ -2811,6 +2830,7 @@ int Socket::GetPooledSocket(SocketUniquePtr* pooled_socket) { if (socket_pool == NULL) { SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; opt.initial_ssl_ctx = _ssl_ctx; @@ -2912,6 +2932,7 @@ int Socket::GetShortSocket(SocketUniquePtr* short_socket) { SocketId id; SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; opt.initial_ssl_ctx = _ssl_ctx; diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 03ad43f867..a3e2323056 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -250,6 +250,8 @@ struct SocketOptions { // user->BeforeRecycle() before recycling. int fd{-1}; butil::EndPoint remote_side; + butil::EndPoint local_side; + std::string device_name; // If `connect_on_create' is true and `fd' is less than 0, // a client connection will be established to remote_side() // regarding deadline `connect_abstime' when Socket is being created. @@ -830,6 +832,9 @@ friend void DereferenceSocket(Socket*); // Address of self. Initialized in ResetFileDescriptor(). butil::EndPoint _local_side; + // The device name of the client's network adapter. + std::string _device_name; + // Called when edge-triggered events happened on `_fd'. Read comments // of EventDispatcher::AddConsumer (event_dispatcher.h) // carefully before implementing the callback. diff --git a/src/brpc/socket_map.cpp b/src/brpc/socket_map.cpp index 14bea71db5..16e6986394 100644 --- a/src/brpc/socket_map.cpp +++ b/src/brpc/socket_map.cpp @@ -92,8 +92,10 @@ SocketMap* get_or_new_client_side_socket_map() { int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option) { - return get_or_new_client_side_socket_map()->Insert(key, id, ssl_ctx, use_rdma, hc_option); + const HealthCheckOption& hc_option, + const butil::EndPoint& client_endpoint, + const std::string& device_name) { + return get_or_new_client_side_socket_map()->Insert(key, id, ssl_ctx, use_rdma, hc_option, client_endpoint, device_name); } int SocketMapFind(const SocketMapKey& key, SocketId* id) { @@ -229,7 +231,9 @@ void SocketMap::ShowSocketMapInBvarIfNeed() { int SocketMap::Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option) { + const HealthCheckOption& hc_option, + const butil::EndPoint& client_endpoint, + const std::string& device_name) { ShowSocketMapInBvarIfNeed(); std::unique_lock mu(_mutex); @@ -251,6 +255,8 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id, SocketId tmp_id; SocketOptions opt; opt.remote_side = key.peer.addr; + opt.local_side = client_endpoint; + opt.device_name = device_name; opt.initial_ssl_ctx = ssl_ctx; opt.use_rdma = use_rdma; opt.hc_option = hc_option; diff --git a/src/brpc/socket_map.h b/src/brpc/socket_map.h index b0d542e78e..c939bf1bc6 100644 --- a/src/brpc/socket_map.h +++ b/src/brpc/socket_map.h @@ -82,18 +82,22 @@ struct SocketMapKeyHasher { int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option); + const HealthCheckOption& hc_option, + const butil::EndPoint& client_endpoint, + const std::string& device_name); inline int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { HealthCheckOption hc_option; - return SocketMapInsert(key, id, ssl_ctx, false, hc_option); + butil::EndPoint endpoint; + return SocketMapInsert(key, id, ssl_ctx, false, hc_option, endpoint, ""); } inline int SocketMapInsert(const SocketMapKey& key, SocketId* id) { std::shared_ptr empty_ptr; HealthCheckOption hc_option; - return SocketMapInsert(key, id, empty_ptr, false, hc_option); + butil::EndPoint endpoint; + return SocketMapInsert(key, id, empty_ptr, false, hc_option, endpoint, ""); } // Find the SocketId associated with `key'. @@ -155,17 +159,21 @@ class SocketMap { int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option); + const HealthCheckOption& hc_option, + const butil::EndPoint& client_endpoint, + const std::string& device_name); int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { HealthCheckOption hc_option; - return Insert(key, id, ssl_ctx, false, hc_option); + butil::EndPoint endpoint; + return Insert(key, id, ssl_ctx, false, hc_option, endpoint, ""); } int Insert(const SocketMapKey& key, SocketId* id) { std::shared_ptr empty_ptr; HealthCheckOption hc_option; - return Insert(key, id, empty_ptr, false, hc_option); + butil::EndPoint endpoint; + return Insert(key, id, empty_ptr, false, hc_option, endpoint, ""); } void Remove(const SocketMapKey& key, SocketId expected_id); diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 4a774fab2a..8508a7986c 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -2070,4 +2070,49 @@ TEST_F(ServerTest, auth) { ASSERT_EQ(0, server.Join()); } +void TestClientHost(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed, + brpc::ChannelOptions& copt) { + brpc::Channel chan; + copt.max_retry = 0; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +TEST_F(ServerTest, bind_client_host_and_network_device) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions opt; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Controller cntl; + brpc::ChannelOptions copt; + copt.client_host = "localhost"; + copt.device_name = "lo"; + std::vector connection_types = { + brpc::CONNECTION_TYPE_SINGLE, + brpc::CONNECTION_TYPE_POOLED, + brpc::CONNECTION_TYPE_SHORT + }; + for (auto connect_type : connection_types) { + copt.connection_type = connect_type; + TestClientHost(ep, cntl, 0, false, copt); + cntl.Reset(); + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace