diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 6e5907735c8a8..0d8c88ac14690 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -28,13 +28,17 @@ import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; +import jakarta.servlet.ReadListener; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; +import org.jetbrains.annotations.NotNull; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; @@ -86,6 +90,10 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; + if ("POST".equalsIgnoreCase(request.getMethod()) && "/v1/statement".equals(request.getRequestURI())) { + request = new MultiReadHttpServletRequestWrapper(request); + } + // skip authentication if non-secure or not configured if (!doesRequestSupportAuthentication(request)) { nextFilter.doFilter(request, response); @@ -259,4 +267,88 @@ public Enumeration getHeaders(String name) return super.getHeaders(name); } } + + public static class MultiReadHttpServletRequestWrapper + extends HttpServletRequestWrapper + { + private final byte[] cachedBody; + private final int contentLength; + private final ServletInputStream inputStream; + + public MultiReadHttpServletRequestWrapper(HttpServletRequest request) + throws IOException + { + super(request); + // Cache the body + contentLength = request.getContentLength(); + cachedBody = request.getInputStream().readNBytes(contentLength); + inputStream = new ServletInputStream() + { + ByteArrayInputStream bais = new ByteArrayInputStream(cachedBody); + + @Override + public int read() + throws IOException + { + return bais.read(); + } + + @Override + public int read(@NotNull byte[] b, int off, int len) + throws IOException + { + return bais.read(b, off, len); + } + + // Implement other methods as needed + @Override + public boolean isFinished() + { + return bais.available() == 0; + } + + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) + { + // Not implemented for this simple example + } + + @Override + public int available() + throws IOException + { + return bais.available(); + } + }; + } + + public byte[] getCachedBody() + { + return cachedBody; + } + + @Override + public ServletInputStream getInputStream() + { + return inputStream; + } + + @Override + public int getContentLength() + { + return contentLength; + } + + @Override + public long getContentLengthLong() + { + return contentLength; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 1cb04591b15da..7eb4eff6cde71 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -21,6 +21,8 @@ import com.facebook.airlift.discovery.client.ServiceSelectorManager; import com.facebook.airlift.discovery.client.testing.TestingDiscoveryModule; import com.facebook.airlift.event.client.EventModule; +import com.facebook.airlift.http.server.HttpServer; +import com.facebook.airlift.http.server.HttpServerInfo; import com.facebook.airlift.http.server.TheServlet; import com.facebook.airlift.http.server.testing.TestingHttpServer; import com.facebook.airlift.http.server.testing.TestingHttpServerModule; @@ -58,6 +60,7 @@ import com.facebook.presto.resourcemanager.ResourceManagerClusterStateProvider; import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.server.GracefulShutdownHandler; +import com.facebook.presto.server.HttpServerModule; import com.facebook.presto.server.PluginManager; import com.facebook.presto.server.ServerInfoResource; import com.facebook.presto.server.ServerMainModule; @@ -133,7 +136,6 @@ import static com.google.common.io.MoreFiles.deleteRecursively; import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static com.google.inject.multibindings.Multibinder.newSetBinder; -import static java.lang.Integer.parseInt; import static java.nio.file.Files.createTempDirectory; import static java.nio.file.Files.isDirectory; import static java.util.Objects.requireNonNull; @@ -149,6 +151,8 @@ public class TestingPrestoServer private final PluginManager pluginManager; private final ConnectorManager connectorManager; private final TestingHttpServer server; + private final HttpServer coordServer; + private final HttpServerInfo serverInfo; private final CatalogManager catalogManager; private final TransactionManager transactionManager; private final SqlParser sqlParser; @@ -299,7 +303,6 @@ public TestingPrestoServer( ImmutableList.Builder modules = ImmutableList.builder() .add(new TestingNodeModule(Optional.ofNullable(environment))) - .add(new TestingHttpServerModule(parseInt(coordinator ? coordinatorPort : "0"))) .add(new JsonModule()) .add(installModuleIf( FeaturesConfig.class, @@ -334,6 +337,12 @@ public TestingPrestoServer( newSetBinder(binder, Filter.class, TheServlet.class).addBinding() .to(RequestBlocker.class).in(Scopes.SINGLETON); }); + if (coordinator) { + modules.add(new HttpServerModule()); + } + else { + modules.add(new TestingHttpServerModule(0)); + } if (discoveryUri != null) { requireNonNull(environment, "environment required when discoveryUri is present"); @@ -367,7 +376,17 @@ public TestingPrestoServer( connectorManager = injector.getInstance(ConnectorManager.class); - server = injector.getInstance(TestingHttpServer.class); + if (coordinator) { + coordServer = injector.getInstance(HttpServer.class); + server = null; + serverInfo = injector.getInstance(HttpServerInfo.class); + } + else { + coordServer = null; + server = injector.getInstance(TestingHttpServer.class); + serverInfo = null; + } + catalogManager = injector.getInstance(CatalogManager.class); transactionManager = injector.getInstance(TransactionManager.class); sqlParser = injector.getInstance(SqlParser.class); @@ -582,12 +601,22 @@ public Path getDataDirectory() public URI getBaseUrl() { - return server.getBaseUrl(); + if (coordinator) { + return serverInfo.getHttpUri(); + } + else { + return server.getBaseUrl(); + } } public URI resolve(String path) { - return server.getBaseUrl().resolve(path); + if (coordinator) { + return serverInfo.getHttpUri().resolve(path); + } + else { + return server.getBaseUrl().resolve(path); + } } public HostAndPort getAddress() @@ -597,7 +626,15 @@ public HostAndPort getAddress() public HostAndPort getHttpsAddress() { - URI httpsUri = server.getHttpServerInfo().getHttpsUri(); + + URI httpsUri; + if (coordinator) { + httpsUri = serverInfo.getHttpsUri(); + } + else { + httpsUri = server.getHttpServerInfo().getHttpsUri(); + } + return HostAndPort.fromParts(httpsUri.getHost(), httpsUri.getPort()); }