diff --git a/httplib.h b/httplib.h index b76a17d07a..8411c8c93b 100644 --- a/httplib.h +++ b/httplib.h @@ -7692,7 +7692,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -8403,8 +8404,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); diff --git a/test/test.cc b/test/test.cc index c9d6421719..b54c0f0afb 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4289,10 +4289,21 @@ TEST_F(ServerTest, TooLongRequest) { } request += "_NG"; + auto start = std::chrono::high_resolution_clock::now(); + + cli_.set_keep_alive(true); auto res = cli_.Get(request.c_str()); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast(end - start) + .count(); + ASSERT_TRUE(res); EXPECT_EQ(StatusCode::UriTooLong_414, res->status); + EXPECT_LE(elapsed, 100); + EXPECT_EQ("close", res->get_header_value("Connection")); + EXPECT_FALSE(cli_.is_socket_open()); } TEST_F(ServerTest, AlmostTooLongRequest) { @@ -4363,10 +4374,21 @@ TEST_F(ServerTest, LongHeader) { } TEST_F(ServerTest, LongQueryValue) { + auto start = std::chrono::high_resolution_clock::now(); + + cli_.set_keep_alive(true); auto res = cli_.Get(LONG_QUERY_URL.c_str()); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast(end - start) + .count(); + ASSERT_TRUE(res); EXPECT_EQ(StatusCode::UriTooLong_414, res->status); + EXPECT_LE(elapsed, 100); + EXPECT_EQ("close", res->get_header_value("Connection")); + EXPECT_FALSE(cli_.is_socket_open()); } TEST_F(ServerTest, TooLongQueryValue) { @@ -4460,6 +4482,7 @@ TEST_F(ServerTest, HeaderCountExceedsLimit) { } // This should fail due to exceeding header count limit + cli_.set_keep_alive(true); auto res = cli_.Get("/hi", headers); // The request should either fail or return 400 Bad Request @@ -4470,6 +4493,9 @@ TEST_F(ServerTest, HeaderCountExceedsLimit) { // Or the request should fail entirely EXPECT_FALSE(res); } + + EXPECT_EQ("close", res->get_header_value("Connection")); + EXPECT_FALSE(cli_.is_socket_open()); } TEST_F(ServerTest, PercentEncoding) { @@ -4524,6 +4550,7 @@ TEST_F(ServerTest, HeaderCountSecurityTest) { } // Try to POST with excessive headers + cli_.set_keep_alive(true); auto res = cli_.Post("/", attack_headers, "test_data", "text/plain"); // Should either fail or return 400 Bad Request due to security limit @@ -4534,6 +4561,9 @@ TEST_F(ServerTest, HeaderCountSecurityTest) { // Request failed, which is the expected behavior for DoS protection EXPECT_FALSE(res); } + + EXPECT_EQ("close", res->get_header_value("Connection")); + EXPECT_FALSE(cli_.is_socket_open()); } TEST_F(ServerTest, MultipartFormData) { @@ -5854,6 +5884,20 @@ TEST_F(ServerTest, TooManyRedirect) { EXPECT_EQ(Error::ExceedRedirectCount, res.error()); } +TEST_F(ServerTest, BadRequestLineCancelsKeepAlive) { + Request req; + req.method = "FOOBAR"; + req.path = "/hi"; + + cli_.set_keep_alive(true); + auto res = cli_.send(req); + + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::BadRequest_400, res->status); + EXPECT_EQ("close", res->get_header_value("Connection")); + EXPECT_FALSE(cli_.is_socket_open()); +} + TEST_F(ServerTest, StartTime) { auto res = cli_.Get("/test-start-time"); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT