diff --git a/.gitignore b/.gitignore index a4cacc7732..21839782df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,30 @@ tags +# Ignore executables (no extension) but not source files example/server +!example/server.* example/client +!example/client.* example/hello +!example/hello.* example/simplecli +!example/simplecli.* example/simplesvr +!example/simplesvr.* example/benchmark +!example/benchmark.* example/redirect -example/sse* +!example/redirect.* +example/ssecli +example/ssesvr example/upload +!example/upload.* example/one_time_request +!example/one_time_request.* example/server_and_client +!example/server_and_client.* +example/accept_header +!example/accept_header.* example/*.pem test/httplib.cc test/httplib.h diff --git a/docker/main.cc b/docker/main.cc index 8ffbf2ca15..edc7093bda 100644 --- a/docker/main.cc +++ b/docker/main.cc @@ -48,32 +48,10 @@ std::string get_error_time_format() { return ss.str(); } -std::string get_client_ip(const Request &req) { - // Check for X-Forwarded-For header first (common in reverse proxy setups) - auto forwarded_for = req.get_header_value("X-Forwarded-For"); - if (!forwarded_for.empty()) { - // Get the first IP if there are multiple - auto comma_pos = forwarded_for.find(','); - if (comma_pos != std::string::npos) { - return forwarded_for.substr(0, comma_pos); - } - return forwarded_for; - } - - // Check for X-Real-IP header - auto real_ip = req.get_header_value("X-Real-IP"); - if (!real_ip.empty()) { return real_ip; } - - // Fallback to remote address (though cpp-httplib doesn't provide this - // directly) For demonstration, we'll use a placeholder - return "127.0.0.1"; -} - // NGINX Combined log format: // $remote_addr - $remote_user [$time_local] "$request" $status $body_bytes_sent // "$http_referer" "$http_user_agent" void nginx_access_logger(const Request &req, const Response &res) { - auto remote_addr = get_client_ip(req); std::string remote_user = "-"; // cpp-httplib doesn't have built-in auth user tracking auto time_local = get_time_format(); @@ -86,7 +64,7 @@ void nginx_access_logger(const Request &req, const Response &res) { if (http_user_agent.empty()) http_user_agent = "-"; std::cout << std::format("{} - {} [{}] \"{}\" {} {} \"{}\" \"{}\"", - remote_addr, remote_user, time_local, request, + req.remote_addr, remote_user, time_local, request, status, body_bytes_sent, http_referer, http_user_agent) << std::endl; @@ -100,7 +78,6 @@ void nginx_error_logger(const Error &err, const Request *req) { std::string level = "error"; if (req) { - auto client_ip = get_client_ip(*req); auto request = std::format("{} {} {}", req->method, req->path, req->version); auto host = req->get_header_value("Host"); @@ -108,8 +85,8 @@ void nginx_error_logger(const Error &err, const Request *req) { std::cerr << std::format("{} [{}] {}, client: {}, request: " "\"{}\", host: \"{}\"", - time_local, level, to_string(err), client_ip, - request, host) + time_local, level, to_string(err), + req->remote_addr, request, host) << std::endl; } else { // If no request context, just log the error @@ -131,6 +108,10 @@ void print_usage(const char *program_name) { std::cout << " Format: mount_point:document_root" << std::endl; std::cout << " (default: /:./html)" << std::endl; + std::cout << " --trusted-proxy Add trusted proxy IP address" + << std::endl; + std::cout << " (can be specified multiple times)" + << std::endl; std::cout << " --version Show version information" << std::endl; std::cout << " --help Show this help message" << std::endl; @@ -140,6 +121,9 @@ void print_usage(const char *program_name) { << " --host localhost --port 8080 --mount /:./html" << std::endl; std::cout << " " << program_name << " --host 0.0.0.0 --port 3000 --mount /api:./api" << std::endl; + std::cout << " " << program_name + << " --trusted-proxy 192.168.1.100 --trusted-proxy 10.0.0.1" + << std::endl; } struct ServerConfig { @@ -147,6 +131,7 @@ struct ServerConfig { int port = 8080; std::string mount_point = "/"; std::string document_root = "./html"; + std::vector trusted_proxies; }; enum class ParseResult { SUCCESS, HELP_REQUESTED, VERSION_REQUESTED, ERROR }; @@ -205,6 +190,14 @@ ParseResult parse_command_line(int argc, char *argv[], ServerConfig &config) { } else if (strcmp(argv[i], "--version") == 0) { std::cout << CPPHTTPLIB_VERSION << std::endl; return ParseResult::VERSION_REQUESTED; + } else if (strcmp(argv[i], "--trusted-proxy") == 0) { + if (i + 1 >= argc) { + std::cerr << "Error: --trusted-proxy requires an IP address argument" + << std::endl; + print_usage(argv[0]); + return ParseResult::ERROR; + } + config.trusted_proxies.push_back(argv[++i]); } else { std::cerr << "Error: Unknown option '" << argv[i] << "'" << std::endl; print_usage(argv[0]); @@ -218,6 +211,11 @@ bool setup_server(Server &svr, const ServerConfig &config) { svr.set_logger(nginx_access_logger); svr.set_error_logger(nginx_error_logger); + // Set trusted proxies if specified + if (!config.trusted_proxies.empty()) { + svr.set_trusted_proxies(config.trusted_proxies); + } + auto ret = svr.set_mount_point(config.mount_point, config.document_root); if (!ret) { std::cerr @@ -285,6 +283,16 @@ int main(int argc, char *argv[]) { << std::endl; std::cout << "Mount point: " << config.mount_point << " -> " << config.document_root << std::endl; + + if (!config.trusted_proxies.empty()) { + std::cout << "Trusted proxies: "; + for (size_t i = 0; i < config.trusted_proxies.size(); ++i) { + if (i > 0) std::cout << ", "; + std::cout << config.trusted_proxies[i]; + } + std::cout << std::endl; + } + std::cout << "Press Ctrl+C to shutdown gracefully..." << std::endl; auto ret = svr.listen(config.hostname, config.port); diff --git a/example/ssesvr.cc b/example/ssesvr.cc index 547b864f3d..6d7390716e 100644 --- a/example/ssesvr.cc +++ b/example/ssesvr.cc @@ -14,11 +14,18 @@ class EventDispatcher { public: EventDispatcher() {} - void wait_event(DataSink *sink) { + bool wait_event(DataSink *sink) { unique_lock lk(m_); int id = id_; - cv_.wait(lk, [&] { return cid_ == id; }); + + // Wait with timeout to prevent hanging if client disconnects + if (!cv_.wait_for(lk, std::chrono::seconds(5), + [&] { return cid_ == id; })) { + return false; // Timeout occurred + } + sink->write(message_.data(), message_.size()); + return true; } void send_event(const string &message) { @@ -71,8 +78,7 @@ int main(void) { cout << "connected to event1..." << endl; res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, DataSink &sink) { - ed.wait_event(&sink); - return true; + return ed.wait_event(&sink); }); }); @@ -80,8 +86,7 @@ int main(void) { cout << "connected to event2..." << endl; res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, DataSink &sink) { - ed.wait_event(&sink); - return true; + return ed.wait_event(&sink); }); }); diff --git a/httplib.h b/httplib.h index d171e56b6b..b76a17d07a 100644 --- a/httplib.h +++ b/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.26.0" -#define CPPHTTPLIB_VERSION_NUM "0x001A00" +#define CPPHTTPLIB_VERSION "0.27.0" +#define CPPHTTPLIB_VERSION_NUM "0x001B00" /* * Platform compatibility check @@ -1132,6 +1132,8 @@ class Server { Server & set_header_writer(std::function const &writer); + Server &set_trusted_proxies(const std::vector &proxies); + Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); @@ -1170,6 +1172,9 @@ class Server { const std::function &setup_request); std::atomic svr_sock_{INVALID_SOCKET}; + + std::vector trusted_proxies_; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; @@ -4600,13 +4605,35 @@ inline bool zstd_decompressor::decompress(const char *data, size_t data_length, } #endif +inline bool is_prohibited_header_name(const std::string &name) { + using udl::operator""_t; + + switch (str2tag(name)) { + case "REMOTE_ADDR"_t: + case "REMOTE_PORT"_t: + case "LOCAL_ADDR"_t: + case "LOCAL_PORT"_t: return true; + default: return false; + } +} + inline bool has_header(const Headers &headers, const std::string &key) { + if (is_prohibited_header_name(key)) { return false; } return headers.find(key) != headers.end(); } inline const char *get_header_value(const Headers &headers, const std::string &key, const char *def, size_t id) { + if (is_prohibited_header_name(key)) { +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "Prohibited header name '" + key + "' is specified."; + throw std::invalid_argument(msg); +#else + return ""; +#endif + } + auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); @@ -7501,6 +7528,12 @@ inline Server &Server::set_header_writer( return *this; } +inline Server & +Server::set_trusted_proxies(const std::vector &proxies) { + trusted_proxies_ = proxies; + return *this; +} + inline Server &Server::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; return *this; @@ -8289,6 +8322,40 @@ inline bool Server::dispatch_request_for_content_reader( return false; } +inline std::string +get_client_ip(const std::string &x_forwarded_for, + const std::vector &trusted_proxies) { + // X-Forwarded-For is a comma-separated list per RFC 7239 + std::vector ip_list; + detail::split(x_forwarded_for.data(), + x_forwarded_for.data() + x_forwarded_for.size(), ',', + [&](const char *b, const char *e) { + auto r = detail::trim(b, e, 0, static_cast(e - b)); + ip_list.emplace_back(std::string(b + r.first, b + r.second)); + }); + + for (size_t i = 0; i < ip_list.size(); ++i) { + auto ip = ip_list[i]; + + auto is_trusted_proxy = + std::any_of(trusted_proxies.begin(), trusted_proxies.end(), + [&](const std::string &proxy) { return ip == proxy; }); + + if (is_trusted_proxy) { + if (i == 0) { + // If the trusted proxy is the first IP, there's no preceding client IP + return ip; + } else { + // Return the IP immediately before the trusted proxy + return ip_list[i - 1]; + } + } + } + + // If no trusted proxy is found, return the first IP in the list + return ip_list.front(); +} + inline bool Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, @@ -8352,15 +8419,16 @@ Server::process_request(Stream &strm, const std::string &remote_addr, connection_closed = true; } - req.remote_addr = remote_addr; + if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) { + auto x_forwarded_for = req.get_header_value("X-Forwarded-For"); + req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_); + } else { + req.remote_addr = remote_addr; + } req.remote_port = remote_port; - req.set_header("REMOTE_ADDR", req.remote_addr); - req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); req.local_addr = local_addr; req.local_port = local_port; - req.set_header("LOCAL_ADDR", req.local_addr); - req.set_header("LOCAL_PORT", std::to_string(req.local_port)); if (req.has_header("Accept")) { const auto &accept_header = req.get_header_value("Accept"); @@ -11306,21 +11374,22 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { for (decltype(count) i = 0; i < count && !dsn_matched; i++) { auto val = sk_GENERAL_NAME_value(alt_names, i); - if (val->type == type) { - auto name = - reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); - auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); - - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || - !memcmp(&addr, name, addr_len)) { - ip_matched = true; - } - break; + if (!val || val->type != type) { continue; } + + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + if (name == nullptr) { continue; } + + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; } + break; } } diff --git a/test/test.cc b/test/test.cc index d6c06194cb..c9d6421719 100644 --- a/test/test.cc +++ b/test/test.cc @@ -11,6 +11,7 @@ #endif #include +#include #include #include #include @@ -78,6 +79,73 @@ static void read_file(const std::string &path, std::string &out) { fs.read(&out[0], static_cast(size)); } +void performance_test(const char *host) { + auto port = 1234; + + Server svr; + + svr.Get("/benchmark", [&](const Request & /*req*/, Response &res) { + res.set_content("Benchmark Response", "text/plain"); + }); + + auto listen_thread = std::thread([&]() { svr.listen(host, port); }); + auto se = detail::scope_exit([&] { + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(host, port); + + // Warm-up request to establish connection and resolve DNS + auto warmup_res = cli.Get("/benchmark"); + ASSERT_TRUE(warmup_res); // Ensure server is responding correctly + + // Run multiple trials and collect timings + const int num_trials = 20; + std::vector timings; + timings.reserve(num_trials); + + for (int i = 0; i < num_trials; i++) { + auto start = std::chrono::high_resolution_clock::now(); + auto res = cli.Get("/benchmark"); + auto end = std::chrono::high_resolution_clock::now(); + + auto elapsed = + std::chrono::duration_cast(end - start) + .count(); + + // Assertions after timing measurement to avoid overhead + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + timings.push_back(elapsed); + } + + // Calculate 25th percentile (lower quartile) + std::sort(timings.begin(), timings.end()); + auto p25 = timings[num_trials / 4]; + + // Format timings for output + std::ostringstream timings_str; + timings_str << "["; + for (size_t i = 0; i < timings.size(); i++) { + if (i > 0) timings_str << ", "; + timings_str << timings[i]; + } + timings_str << "]"; + + // Localhost HTTP GET should be fast even in CI environments + EXPECT_LE(p25, 5) << "25th percentile performance is too slow: " << p25 + << "ms (Issue #1777). Timings: " << timings_str.str(); +} + +TEST(BenchmarkTest, localhost) { performance_test("localhost"); } + +TEST(BenchmarkTest, v6) { performance_test("::1"); } + class UnixSocketTest : public ::testing::Test { protected: void TearDown() override { std::remove(pathname_.c_str()); } @@ -128,7 +196,7 @@ TEST_F(UnixSocketTest, PeerPid) { std::string remote_port_val; svr.Get(pattern_, [&](const httplib::Request &req, httplib::Response &res) { res.set_content(content_, "text/plain"); - remote_port_val = req.get_header_value("REMOTE_PORT"); + remote_port_val = std::to_string(req.remote_port); }); std::thread t{[&] { @@ -3032,21 +3100,20 @@ class ServerTest : public ::testing::Test { #endif .Get("/remote_addr", [&](const Request &req, Response &res) { - auto remote_addr = req.headers.find("REMOTE_ADDR")->second; - EXPECT_TRUE(req.has_header("REMOTE_PORT")); - EXPECT_EQ(req.remote_addr, req.get_header_value("REMOTE_ADDR")); - EXPECT_EQ(req.remote_port, - std::stoi(req.get_header_value("REMOTE_PORT"))); - res.set_content(remote_addr.c_str(), "text/plain"); + ASSERT_FALSE(req.has_header("REMOTE_ADDR")); + ASSERT_FALSE(req.has_header("REMOTE_PORT")); + ASSERT_ANY_THROW(req.get_header_value("REMOTE_ADDR")); + ASSERT_ANY_THROW(req.get_header_value("REMOTE_PORT")); + res.set_content(req.remote_addr, "text/plain"); }) .Get("/local_addr", [&](const Request &req, Response &res) { - EXPECT_TRUE(req.has_header("LOCAL_PORT")); - EXPECT_TRUE(req.has_header("LOCAL_ADDR")); - auto local_addr = req.get_header_value("LOCAL_ADDR"); - auto local_port = req.get_header_value("LOCAL_PORT"); - EXPECT_EQ(req.local_addr, local_addr); - EXPECT_EQ(req.local_port, std::stoi(local_port)); + ASSERT_FALSE(req.has_header("LOCAL_ADDR")); + ASSERT_FALSE(req.has_header("LOCAL_PORT")); + ASSERT_ANY_THROW(req.get_header_value("LOCAL_ADDR")); + ASSERT_ANY_THROW(req.get_header_value("LOCAL_PORT")); + auto local_addr = req.local_addr; + auto local_port = std::to_string(req.local_port); res.set_content(local_addr.append(":").append(local_port), "text/plain"); }) @@ -3634,46 +3701,6 @@ TEST_F(ServerTest, GetMethod200) { EXPECT_EQ("Hello World!", res->body); } -void performance_test(const char *host) { - auto port = 1234; - - Server svr; - - svr.Get("/benchmark", [&](const Request & /*req*/, Response &res) { - res.set_content("Benchmark Response", "text/plain"); - }); - - auto listen_thread = std::thread([&]() { svr.listen(host, port); }); - auto se = detail::scope_exit([&] { - svr.stop(); - listen_thread.join(); - ASSERT_FALSE(svr.is_running()); - }); - - svr.wait_until_ready(); - - Client cli(host, port); - - auto start = std::chrono::high_resolution_clock::now(); - - auto res = cli.Get("/benchmark"); - ASSERT_TRUE(res); - EXPECT_EQ(StatusCode::OK_200, res->status); - - auto end = std::chrono::high_resolution_clock::now(); - - auto elapsed = - std::chrono::duration_cast(end - start) - .count(); - - EXPECT_LE(elapsed, 5) << "Performance is too slow: " << elapsed - << "ms (Issue #1777)"; -} - -TEST(BenchmarkTest, localhost) { performance_test("localhost"); } - -TEST(BenchmarkTest, v6) { performance_test("::1"); } - TEST_F(ServerTest, GetEmptyFile) { auto res = cli_.Get("/empty_file"); ASSERT_TRUE(res); @@ -11043,11 +11070,18 @@ class EventDispatcher { public: EventDispatcher() {} - void wait_event(DataSink *sink) { + bool wait_event(DataSink *sink) { unique_lock lk(m_); int id = id_; - cv_.wait(lk, [&] { return cid_ == id; }); + + // Wait with timeout to prevent hanging if client disconnects + if (!cv_.wait_for(lk, std::chrono::seconds(5), + [&] { return cid_ == id; })) { + return false; // Timeout occurred + } + sink->write(message_.data(), message_.size()); + return true; } void send_event(const string &message) { @@ -11072,8 +11106,7 @@ TEST(ClientInThreadTest, Issue2068) { svr.Get("/event1", [&](const Request & /*req*/, Response &res) { res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, DataSink &sink) { - ed.wait_event(&sink); - return true; + return ed.wait_event(&sink); }); }); @@ -11116,9 +11149,11 @@ TEST(ClientInThreadTest, Issue2068) { std::this_thread::sleep_for(std::chrono::seconds(2)); stop = true; client->stop(); - client.reset(); t.join(); + + // Reset client after thread has finished + client.reset(); } } @@ -11172,3 +11207,240 @@ TEST(HeaderSmugglingTest, ChunkedTrailerHeadersMerged) { std::string res; ASSERT_TRUE(send_request(1, req, &res)); } + +TEST(ForwardedHeadersTest, NoProxiesSetting) { + Server svr; + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = cli.Get("/ip", {{"X-Forwarded-For", "203.0.113.66"}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, "203.0.113.66"); + EXPECT_TRUE(observed_remote_addr == "::1" || + observed_remote_addr == "127.0.0.1"); +} + +TEST(ForwardedHeadersTest, NoForwardedHeaders) { + Server svr; + + svr.set_trusted_proxies({"203.0.113.66"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = cli.Get("/ip"); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, ""); + EXPECT_TRUE(observed_remote_addr == "::1" || + observed_remote_addr == "127.0.0.1"); +} + +TEST(ForwardedHeadersTest, SingleTrustedProxy_UsesIPBeforeTrusted) { + Server svr; + + svr.set_trusted_proxies({"203.0.113.66"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = + cli.Get("/ip", {{"X-Forwarded-For", "198.51.100.23, 203.0.113.66"}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, "198.51.100.23, 203.0.113.66"); + EXPECT_EQ(observed_remote_addr, "198.51.100.23"); +} + +TEST(ForwardedHeadersTest, MultipleTrustedProxies_UsesClientIP) { + Server svr; + + svr.set_trusted_proxies({"203.0.113.66", "192.0.2.45"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = cli.Get( + "/ip", {{"X-Forwarded-For", "198.51.100.23, 203.0.113.66, 192.0.2.45"}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, "198.51.100.23, 203.0.113.66, 192.0.2.45"); + EXPECT_EQ(observed_remote_addr, "198.51.100.23"); +} + +TEST(ForwardedHeadersTest, TrustedProxyNotInHeader_UsesFirstFromXFF) { + Server svr; + + svr.set_trusted_proxies({"192.0.2.45"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = + cli.Get("/ip", {{"X-Forwarded-For", "198.51.100.23, 198.51.100.24"}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, "198.51.100.23, 198.51.100.24"); + EXPECT_EQ(observed_remote_addr, "198.51.100.23"); +} + +TEST(ForwardedHeadersTest, LastHopTrusted_SelectsImmediateLeftIP) { + Server svr; + + svr.set_trusted_proxies({"192.0.2.45"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = cli.Get( + "/ip", {{"X-Forwarded-For", "198.51.100.23, 203.0.113.66, 192.0.2.45"}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + EXPECT_EQ(observed_xff, "198.51.100.23, 203.0.113.66, 192.0.2.45"); + EXPECT_EQ(observed_remote_addr, "203.0.113.66"); +} + +TEST(ForwardedHeadersTest, HandlesWhitespaceAroundIPs) { + Server svr; + + svr.set_trusted_proxies({"192.0.2.45"}); + + std::string observed_remote_addr; + std::string observed_xff; + + svr.Get("/ip", [&](const Request &req, Response &res) { + observed_remote_addr = req.remote_addr; + observed_xff = req.get_header_value("X-Forwarded-For"); + res.set_content("ok", "text/plain"); + }); + + thread t = thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + auto res = cli.Get("/ip", {{"X-Forwarded-For", + " 198.51.100.23 , 203.0.113.66 , 192.0.2.45 "}}); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + + // Header parser trims surrounding whitespace of the header value + EXPECT_EQ(observed_xff, "198.51.100.23 , 203.0.113.66 , 192.0.2.45"); + EXPECT_EQ(observed_remote_addr, "203.0.113.66"); +}