diff --git a/core/pom.xml b/core/pom.xml
index 504a409027..1878d5bbff 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -257,5 +257,10 @@
mockito-inline
test
+
+ com.squareup.okhttp3
+ okhttp
+ test
+
diff --git a/core/src/main/java/io/confluent/rest/Application.java b/core/src/main/java/io/confluent/rest/Application.java
index 1ef2c75b9b..10a36165e9 100644
--- a/core/src/main/java/io/confluent/rest/Application.java
+++ b/core/src/main/java/io/confluent/rest/Application.java
@@ -28,6 +28,7 @@
import io.confluent.rest.exceptions.JsonMappingExceptionMapper;
import io.confluent.rest.exceptions.JsonParseExceptionMapper;
import io.confluent.rest.extension.ResourceExtension;
+import io.confluent.rest.filters.ConnectionDurationFilter;
import io.confluent.rest.filters.CsrfTokenProtectionFilter;
import io.confluent.rest.handlers.SniHandler;
import io.confluent.rest.metrics.Jetty429MetricsDosFilterListener;
@@ -402,6 +403,8 @@ public Handler configureHandler() {
configureDosFilters(context);
+ configureConnectionDurationFilter(context);
+
configurePreResourceHandling(context);
context.addFilter(servletHolder, "/*", null);
configurePostResourceHandling(context);
@@ -751,6 +754,16 @@ private void configureGlobalDosFilter(ServletContextHandler context) {
context.addFilter(filterHolder, "/*", EnumSet.of(DispatcherType.REQUEST));
}
+ private void configureConnectionDurationFilter(ServletContextHandler context) {
+ if (config.getMaxConnectionDuration() > 0) {
+ FilterHolder filterHolder = new FilterHolder(ConnectionDurationFilter.class);
+ filterHolder.setName("connection-duration-filter");
+ filterHolder.setInitParameter(RestConfig.MAX_CONNECTION_DURATION_MS,
+ String.valueOf(config.getMaxConnectionDuration()));
+ context.addFilter(filterHolder, "/*", EnumSet.of(DispatcherType.REQUEST));
+ }
+ }
+
private FilterHolder configureDosFilter(DoSFilter dosFilter, String rate) {
FilterHolder filterHolder = new FilterHolder(dosFilter);
diff --git a/core/src/main/java/io/confluent/rest/RestConfig.java b/core/src/main/java/io/confluent/rest/RestConfig.java
index 6de18bd754..b9c661d5bc 100644
--- a/core/src/main/java/io/confluent/rest/RestConfig.java
+++ b/core/src/main/java/io/confluent/rest/RestConfig.java
@@ -478,6 +478,13 @@ public class RestConfig extends AbstractConfig {
+ "If the limit is set to a non-positive number, no limit is applied. Default is 0.";
private static final int SERVER_CONNECTION_LIMIT_DEFAULT = 0;
+ public static final String MAX_CONNECTION_DURATION_MS = "max.connection.duration.ms";
+ public static final String MAX_CONNECTION_DURATION_MS_DOC =
+ "The maximum duration in milliseconds that a connection can be open. "
+ + "If a connection is open for longer than this duration, it will be closed. "
+ + "If set to 0, no limit is applied. Default is 0.";
+ protected static final long MAX_CONNECTION_DURATION_MS_DEFAULT = 0;
+
// For rest-utils applications connectors correspond to configured listeners. See
// ApplicationServer#parseListeners for more details.
private static final String CONNECTOR_CONNECTION_LIMIT = "connector.connection.limit";
@@ -1050,6 +1057,11 @@ private static ConfigDef incompleteBaseConfigDef() {
SERVER_CONNECTION_LIMIT_DEFAULT,
Importance.LOW,
SERVER_CONNECTION_LIMIT_DOC
+ ).define(MAX_CONNECTION_DURATION_MS,
+ Type.LONG,
+ MAX_CONNECTION_DURATION_MS_DEFAULT,
+ Importance.LOW,
+ MAX_CONNECTION_DURATION_MS_DOC
).define(
CONNECTOR_CONNECTION_LIMIT,
Type.INT,
@@ -1258,6 +1270,10 @@ public final int getServerConnectionLimit() {
return getInt(SERVER_CONNECTION_LIMIT);
}
+ public final long getMaxConnectionDuration() {
+ return getLong(MAX_CONNECTION_DURATION_MS);
+ }
+
public final int getConnectorConnectionLimit() {
return getInt(CONNECTOR_CONNECTION_LIMIT);
}
diff --git a/core/src/main/java/io/confluent/rest/filters/ConnectionDurationFilter.java b/core/src/main/java/io/confluent/rest/filters/ConnectionDurationFilter.java
new file mode 100644
index 0000000000..b242d8da0a
--- /dev/null
+++ b/core/src/main/java/io/confluent/rest/filters/ConnectionDurationFilter.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2024 Confluent Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.confluent.rest.filters;
+
+import io.confluent.rest.RestConfig;
+import java.io.IOException;
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletResponse;
+import org.eclipse.jetty.server.HttpChannel;
+import org.eclipse.jetty.server.Response;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Simple server-side request filter that limits the connection age.
+ *
+ *
+ *
Long-running client connections can be problematic for multiple server instances behind
+ * an NLB due to uneven load distribution (especially in case of HTTP/2.0 that specifically
+ * encourages long-lived connections). This filter closes any connection that receives a request
+ * after the connection has been open for {@link RestConfig#MAX_CONNECTION_DURATION_MS} ms.
+ */
+public class ConnectionDurationFilter implements Filter {
+
+ private static final Logger log = LoggerFactory.getLogger(ConnectionDurationFilter.class);
+ private static long MAX_CONNECTION_DURATION_MS = -1;
+
+ @Override
+ public void init(FilterConfig filterConfig) throws ServletException {
+ String maxConnectionDuration = filterConfig.getInitParameter(
+ RestConfig.MAX_CONNECTION_DURATION_MS);
+ if (maxConnectionDuration != null) {
+ MAX_CONNECTION_DURATION_MS = Long.parseLong(maxConnectionDuration);
+ }
+ }
+
+ @Override
+ public void doFilter(
+ ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+ throws IOException, ServletException {
+
+ HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
+
+ HttpChannel channel = ((Response) httpServletResponse).getHttpChannel();
+
+ long connectionCreationTime = channel.getConnection().getCreatedTimeStamp();
+ long connectionAge = System.currentTimeMillis() - connectionCreationTime;
+
+ if (connectionAge > MAX_CONNECTION_DURATION_MS) {
+ log.debug("Connection from remote peer {} has been active for {}ms. Closing the connection.",
+ channel.getRemoteAddress(), connectionAge);
+ channel.getEndPoint().close();
+ } else {
+ log.trace("Connection from remote peer {} is {}ms old. Leaving the connection as is",
+ channel.getRemoteAddress(), connectionAge);
+ }
+
+ filterChain.doFilter(servletRequest, servletResponse);
+
+ }
+
+ @Override
+ public void destroy() {
+
+ }
+
+}
diff --git a/core/src/test/java/io/confluent/rest/ConnectionDurationFilterTest.java b/core/src/test/java/io/confluent/rest/ConnectionDurationFilterTest.java
new file mode 100644
index 0000000000..d92f164e89
--- /dev/null
+++ b/core/src/test/java/io/confluent/rest/ConnectionDurationFilterTest.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright 2024 Confluent Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.confluent.rest;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.google.common.collect.ImmutableMap;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import javax.servlet.http.HttpServletRequest;
+import javax.ws.rs.GET;
+import javax.ws.rs.Path;
+import javax.ws.rs.core.Configurable;
+import javax.ws.rs.core.Context;
+import javax.ws.rs.core.Response.Status;
+import javax.ws.rs.core.UriBuilder;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.Response;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.impl.conn.BasicHttpClientConnectionManager;
+import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
+import org.apache.http.util.EntityUtils;
+import org.eclipse.jetty.server.Server;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+
+/*
+ * Note on client support --
+ * The Apache HttpClient library v4.x does not support HTTP/2.0, so we use
+ * OkHttp for HTTP/2.0 tests.
+ */
+public class ConnectionDurationFilterTest {
+
+ static Server server;
+ private static final long MAX_CONNECTION_DURATION_MS = 100;
+
+ @BeforeAll
+ public static void setUp() throws Exception {
+ MyAddressApplication application =
+ new MyAddressApplication(
+ new MyAddressConfig(
+ ImmutableMap.of(
+ "listeners", "http://localhost:8080",
+ "http2.enabled", "true",
+ "max.connection.duration.ms", String.valueOf(MAX_CONNECTION_DURATION_MS)))
+ );
+ server = application.createServer();
+ server.start();
+ }
+
+ @AfterAll
+ public static void tearDown() throws Exception {
+ server.stop();
+ }
+
+ @Test
+ public void connectionDurationTest_ApacheHttpClient_Ephemeral_HTTP1() throws Exception {
+
+ CloseableHttpClient client = createPersistentClient();
+ List addressList = new ArrayList<>();
+
+ // Send a few requests with a delay between consecutive requests that is longer
+ // than the connection duration threshold.
+ // Each request should succeed and the server should return a different address for each request.
+ for (int i=0; i<5; i++) {
+
+ HttpGet request = createRequest(server.getURI());
+ assertEquals("HTTP/1.1", request.getProtocolVersion().toString());
+
+ CloseableHttpResponse response = client.execute(request);
+ assertEquals(Status.OK.getStatusCode(), response.getStatusLine().getStatusCode());
+
+ String responseBody = EntityUtils.toString(response.getEntity());
+ assertTrue(responseBody.startsWith("127.0.0.1:"));
+
+ response.close();
+ addressList.add(responseBody);
+ Thread.sleep(MAX_CONNECTION_DURATION_MS + 10);
+ }
+ client.close();
+
+ // check that we actually used 5 different addresses under the hood
+ assertEquals(5, addressList.stream().distinct().count());
+ }
+
+ @Test
+ public void connectionDurationTest_ApacheHttpClient_Persistent_HTTP1() throws Exception {
+
+ CloseableHttpClient client = createEphemeralClient();
+ List addressList = new ArrayList<>();
+
+ // Send a few requests with a delay between consecutive requests that is longer
+ // than the connection duration threshold.
+ // Each request should succeed and the server should return a different address for each request.
+ for (int i=0; i<5; i++) {
+
+ HttpGet request = createRequest(server.getURI());
+ assertEquals("HTTP/1.1", request.getProtocolVersion().toString());
+
+ CloseableHttpResponse response = client.execute(request);
+ assertEquals(Status.OK.getStatusCode(), response.getStatusLine().getStatusCode());
+
+ String responseBody = EntityUtils.toString(response.getEntity());
+ assertTrue(responseBody.startsWith("127.0.0.1:"));
+
+ response.close();
+ addressList.add(responseBody);
+ Thread.sleep(MAX_CONNECTION_DURATION_MS + 10);
+ }
+ client.close();
+
+ // check that we actually used 5 different addresses under the hood
+ assertEquals(5, addressList.stream().distinct().count());
+ }
+
+ @Test
+ public void connectionDurationTest_OkHttpClient_HTTP2() throws Exception {
+
+ // explicitly set HTTP 2.0 protocol
+ OkHttpClient client = new OkHttpClient.Builder()
+ .protocols(Collections.singletonList(okhttp3.Protocol.H2_PRIOR_KNOWLEDGE))
+ .build();
+
+ List addressList = new ArrayList<>();
+
+ // Send a few requests with a delay between consecutive requests that is longer
+ // than the connection duration threshold.
+ // Each request should succeed and the server should return a different address for each request.
+ for (int i=0; i<5; i++) {
+ Request request = new Request.Builder()
+ .url("http://localhost:8080/whatsmyaddress")
+ .build();
+
+ Response response = client.newCall(request).execute();
+ // assert that the request was indeed http 2
+ assertEquals("h2_prior_knowledge", response.protocol().toString());
+
+ assertEquals(Status.OK.getStatusCode(), response.code());
+
+ String responseBody = response.body().string();
+ assertTrue(responseBody.startsWith("127.0.0.1:"));
+
+ response.close();
+ addressList.add(responseBody);
+ Thread.sleep(MAX_CONNECTION_DURATION_MS + 10);
+ }
+
+ // shutdown the client
+ client.dispatcher().executorService().shutdown();
+
+ // check that we actually used 5 different addresses under the hood
+ assertEquals(5, addressList.stream().distinct().count());
+
+ }
+
+ private static HttpGet createRequest(URI serverUri) {
+ return new HttpGet(UriBuilder.fromUri(serverUri).path("/whatsmyaddress").build());
+ }
+
+ private static CloseableHttpClient createEphemeralClient() {
+ return HttpClients.custom()
+ .setConnectionManager(new BasicHttpClientConnectionManager())
+ .setKeepAliveStrategy((httpResponse, httpContext) -> -1)
+ .build();
+ }
+
+ private static CloseableHttpClient createPersistentClient() {
+ return HttpClients.custom()
+ .setConnectionManager(new PoolingHttpClientConnectionManager())
+ .setKeepAliveStrategy((httpResponse, httpContext) -> 5000)
+ .build();
+ }
+
+ public static final class MyAddressApplication extends Application {
+
+ public MyAddressApplication(MyAddressConfig config) {
+ super(config);
+ }
+
+ @Override
+ public void setupResources(Configurable> config, MyAddressConfig appConfig) {
+ config.register(AddressResource.class);
+ }
+ }
+
+ public static final class MyAddressConfig extends RestConfig {
+
+ public MyAddressConfig(Map configs) {
+ super(baseConfigDef(), configs);
+ }
+ }
+
+ @Path("/whatsmyaddress")
+ public static final class AddressResource {
+
+ @GET
+ public String getAddress(@Context HttpServletRequest request) {
+ return request.getRemoteAddr() + ":" + request.getRemotePort();
+ }
+ }
+}
diff --git a/pom.xml b/pom.xml
index d79e7eacaa..82fb13a66a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -59,6 +59,7 @@
7.8.0-0
2.5.2
9.4.53.v20231009
+ 4.9.2
@@ -297,6 +298,12 @@
${mockito.version}
test
+
+ com.squareup.okhttp3
+ okhttp
+ ${okhttp.version}
+ test
+