diff --git a/include/boost/redis/connection.hpp b/include/boost/redis/connection.hpp index 93b7e7fd..ae67fe03 100644 --- a/include/boost/redis/connection.hpp +++ b/include/boost/redis/connection.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -67,7 +68,6 @@ struct connection_impl { Executor, void(system::error_code, std::size_t)>; using health_checker_type = detail::health_checker; - using resp3_handshaker_type = detail::resp3_handshaker; using exec_notifier_type = asio::experimental::channel< Executor, void(system::error_code, std::size_t)>; @@ -81,12 +81,13 @@ struct connection_impl { timer_type reconnect_timer_; // to wait the reconnection period receive_channel_type receive_channel_; health_checker_type health_checker_; - resp3_handshaker_type handshaker_; config cfg_; multiplexer mpx_; connection_logger logger_; read_buffer read_buffer_; + request hello_req_; + generic_response hello_resp_; using executor_type = Executor; @@ -326,6 +327,27 @@ class run_op { using order_t = std::array; + static system::error_code on_hello(connection_impl& conn, system::error_code ec) + { + conn.logger_.on_hello(ec, conn.hello_resp_); + ec = check_hello_response(ec, conn.hello_resp_); + if (ec) { + conn.cancel(operation::run); + } + return ec; + } + + template + auto handshaker(CompletionToken&& token) + { + return conn_->async_exec( + conn_->hello_req_, + any_adapter(conn_->hello_resp_), + asio::deferred([&conn = *this->conn_](system::error_code hello_ec, std::size_t) { + return asio::deferred.values(on_hello(conn, hello_ec)); + }))(std::forward(token)); + } + template auto reader(CompletionToken&& token) { @@ -392,6 +414,9 @@ class run_op { return; } + // Set up the hello request, as it only depends on the config + push_hello(conn_->cfg_, conn_->hello_req_); + for (;;) { // Try to connect BOOST_ASIO_CORO_YIELD @@ -401,6 +426,7 @@ class run_op { if (!ec) { conn_->read_buffer_.clear(); conn_->mpx_.reset(); + clear_response(conn_->hello_resp_); // Note: Order is important here because the writer might // trigger an async_write before the async_hello thereby @@ -408,7 +434,7 @@ class run_op { BOOST_ASIO_CORO_YIELD asio::experimental::make_parallel_group( [this](auto token) { - return conn_->handshaker_.async_hello(*conn_, token); + return this->handshaker(token); }, [this](auto token) { return conn_->health_checker_.async_ping(*conn_, token); @@ -606,7 +632,6 @@ class basic_connection { { impl_->cfg_ = cfg; impl_->health_checker_.set_config(cfg); - impl_->handshaker_.set_config(cfg); impl_->read_buffer_.set_config({cfg.read_buffer_append_size, cfg.max_read_size}); return asio::async_compose( @@ -908,7 +933,6 @@ class basic_connection { executor_type, void(system::error_code, std::size_t)>; using health_checker_type = detail::health_checker; - using resp3_handshaker_type = detail::resp3_handshaker; auto use_ssl() const noexcept { return impl_->cfg_.use_ssl; } diff --git a/include/boost/redis/detail/resp3_handshaker.hpp b/include/boost/redis/detail/resp3_handshaker.hpp index 05edc795..8ba4334e 100644 --- a/include/boost/redis/detail/resp3_handshaker.hpp +++ b/include/boost/redis/detail/resp3_handshaker.hpp @@ -4,110 +4,25 @@ * accompanying file LICENSE.txt) */ -#ifndef BOOST_REDIS_RUNNER_HPP -#define BOOST_REDIS_RUNNER_HPP +#ifndef BOOST_REDIS_RESP3_HANDSHAKER_HPP +#define BOOST_REDIS_RESP3_HANDSHAKER_HPP #include -#include -#include -#include #include #include -#include -#include - -#include - namespace boost::redis::detail { -void push_hello(config const& cfg, request& req); - -// TODO: Can we avoid this whole function whose only purpose is to -// check for an error in the hello response and complete with an error -// so that the parallel group that starts it can exit? -template -struct hello_op { - Handshaker* handshaker_ = nullptr; - ConnectionImpl* conn_ = nullptr; - asio::coroutine coro_{}; - - template - void operator()(Self& self, system::error_code ec = {}, std::size_t = 0) - { - BOOST_ASIO_CORO_REENTER(coro_) - { - handshaker_->add_hello(); - - BOOST_ASIO_CORO_YIELD - conn_->async_exec( - handshaker_->hello_req_, - any_adapter(handshaker_->hello_resp_), - std::move(self)); - conn_->logger_.on_hello(ec, handshaker_->hello_resp_); - - if (ec) { - conn_->cancel(operation::run); - self.complete(ec); - return; - } - - if (handshaker_->has_error_in_response()) { - conn_->cancel(operation::run); - self.complete(error::resp3_hello); - return; - } - - self.complete({}); - } - } -}; - -template -class resp3_handshaker { -public: - void set_config(config const& cfg) { cfg_ = cfg; } - - template - auto async_hello(ConnectionImpl& conn, CompletionToken token) - { - return asio::async_compose( - hello_op{this, &conn}, - token, - conn); - } - -private: - template friend struct hello_op; - - void add_hello() - { - hello_req_.clear(); - if (hello_resp_.has_value()) - hello_resp_.value().clear(); - push_hello(cfg_, hello_req_); - } - - bool has_error_in_response() const noexcept - { - if (!hello_resp_.has_value()) - return true; - - auto f = [](auto const& e) { - switch (e.data_type) { - case resp3::type::simple_error: - case resp3::type::blob_error: return true; - default: return false; - } - }; - - return std::any_of(std::cbegin(hello_resp_.value()), std::cend(hello_resp_.value()), f); - } - - request hello_req_; - generic_response hello_resp_; - config cfg_; -}; +void push_hello(config const& cfg, request& req); // TODO: rename +system::error_code check_hello_response(system::error_code io_ec, const generic_response&); +// TODO: logging should be here, too +inline void clear_response(generic_response& res) +{ + if (res.has_value()) + res->clear(); + else + res.emplace(); +} } // namespace boost::redis::detail diff --git a/include/boost/redis/impl/resp3_handshaker.ipp b/include/boost/redis/impl/resp3_handshaker.ipp index db516419..81f85da4 100644 --- a/include/boost/redis/impl/resp3_handshaker.ipp +++ b/include/boost/redis/impl/resp3_handshaker.ipp @@ -10,6 +10,7 @@ namespace boost::redis::detail { void push_hello(config const& cfg, request& req) { + req.clear(); if (!cfg.username.empty() && !cfg.password.empty() && !cfg.clientname.empty()) req.push("HELLO", "3", "AUTH", cfg.username, cfg.password, "SETNAME", cfg.clientname); else if (cfg.password.empty() && cfg.clientname.empty()) @@ -23,4 +24,24 @@ void push_hello(config const& cfg, request& req) req.push("SELECT", cfg.database_index.value()); } +system::error_code check_hello_response(system::error_code io_ec, const generic_response& resp) +{ + if (io_ec) + return io_ec; + + if (resp.has_error()) + return error::resp3_hello; + + auto f = [](auto const& e) { + switch (e.data_type) { + case resp3::type::simple_error: + case resp3::type::blob_error: return true; + default: return false; + } + }; + + bool has_error = std::any_of(resp->cbegin(), resp->cend(), f); + return has_error ? error::resp3_hello : system::error_code(); +} + } // namespace boost::redis::detail