Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import org.jetbrains.annotations.NotNull;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
Expand Down Expand Up @@ -86,6 +90,10 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;

if ("POST".equalsIgnoreCase(request.getMethod()) && "/v1/statement".equals(request.getRequestURI())) {
request = new MultiReadHttpServletRequestWrapper(request);
}

// skip authentication if non-secure or not configured
if (!doesRequestSupportAuthentication(request)) {
nextFilter.doFilter(request, response);
Expand Down Expand Up @@ -259,4 +267,88 @@ public Enumeration<String> getHeaders(String name)
return super.getHeaders(name);
}
}

public static class MultiReadHttpServletRequestWrapper
extends HttpServletRequestWrapper
{
private final byte[] cachedBody;
private final int contentLength;
private final ServletInputStream inputStream;

public MultiReadHttpServletRequestWrapper(HttpServletRequest request)
throws IOException
{
super(request);
// Cache the body
contentLength = request.getContentLength();
cachedBody = request.getInputStream().readNBytes(contentLength);
inputStream = new ServletInputStream()
{
ByteArrayInputStream bais = new ByteArrayInputStream(cachedBody);

@Override
public int read()
throws IOException
{
return bais.read();
}

@Override
public int read(@NotNull byte[] b, int off, int len)
throws IOException
{
return bais.read(b, off, len);
}

// Implement other methods as needed
@Override
public boolean isFinished()
{
return bais.available() == 0;
}

@Override
public boolean isReady()
{
return true;
}

@Override
public void setReadListener(ReadListener readListener)
{
// Not implemented for this simple example
}

@Override
public int available()
throws IOException
{
return bais.available();
}
};
}

public byte[] getCachedBody()
{
return cachedBody;
}

@Override
public ServletInputStream getInputStream()
{
return inputStream;
}

@Override
public int getContentLength()
{
return contentLength;
}

@Override
public long getContentLengthLong()
{
return contentLength;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import com.facebook.airlift.discovery.client.ServiceSelectorManager;
import com.facebook.airlift.discovery.client.testing.TestingDiscoveryModule;
import com.facebook.airlift.event.client.EventModule;
import com.facebook.airlift.http.server.HttpServer;
import com.facebook.airlift.http.server.HttpServerInfo;
import com.facebook.airlift.http.server.TheServlet;
import com.facebook.airlift.http.server.testing.TestingHttpServer;
import com.facebook.airlift.http.server.testing.TestingHttpServerModule;
Expand Down Expand Up @@ -58,6 +60,7 @@
import com.facebook.presto.resourcemanager.ResourceManagerClusterStateProvider;
import com.facebook.presto.security.AccessControlManager;
import com.facebook.presto.server.GracefulShutdownHandler;
import com.facebook.presto.server.HttpServerModule;
import com.facebook.presto.server.PluginManager;
import com.facebook.presto.server.ServerInfoResource;
import com.facebook.presto.server.ServerMainModule;
Expand Down Expand Up @@ -133,7 +136,6 @@
import static com.google.common.io.MoreFiles.deleteRecursively;
import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static java.lang.Integer.parseInt;
import static java.nio.file.Files.createTempDirectory;
import static java.nio.file.Files.isDirectory;
import static java.util.Objects.requireNonNull;
Expand All @@ -149,6 +151,8 @@ public class TestingPrestoServer
private final PluginManager pluginManager;
private final ConnectorManager connectorManager;
private final TestingHttpServer server;
private final HttpServer coordServer;
private final HttpServerInfo serverInfo;
private final CatalogManager catalogManager;
private final TransactionManager transactionManager;
private final SqlParser sqlParser;
Expand Down Expand Up @@ -299,7 +303,6 @@ public TestingPrestoServer(

ImmutableList.Builder<Module> modules = ImmutableList.<Module>builder()
.add(new TestingNodeModule(Optional.ofNullable(environment)))
.add(new TestingHttpServerModule(parseInt(coordinator ? coordinatorPort : "0")))
.add(new JsonModule())
.add(installModuleIf(
FeaturesConfig.class,
Expand Down Expand Up @@ -334,6 +337,12 @@ public TestingPrestoServer(
newSetBinder(binder, Filter.class, TheServlet.class).addBinding()
.to(RequestBlocker.class).in(Scopes.SINGLETON);
});
if (coordinator) {
modules.add(new HttpServerModule());
}
else {
modules.add(new TestingHttpServerModule(0));
}

if (discoveryUri != null) {
requireNonNull(environment, "environment required when discoveryUri is present");
Expand Down Expand Up @@ -367,7 +376,17 @@ public TestingPrestoServer(

connectorManager = injector.getInstance(ConnectorManager.class);

server = injector.getInstance(TestingHttpServer.class);
if (coordinator) {
coordServer = injector.getInstance(HttpServer.class);
server = null;
serverInfo = injector.getInstance(HttpServerInfo.class);
}
else {
coordServer = null;
server = injector.getInstance(TestingHttpServer.class);
serverInfo = null;
}

catalogManager = injector.getInstance(CatalogManager.class);
transactionManager = injector.getInstance(TransactionManager.class);
sqlParser = injector.getInstance(SqlParser.class);
Expand Down Expand Up @@ -582,12 +601,22 @@ public Path getDataDirectory()

public URI getBaseUrl()
{
return server.getBaseUrl();
if (coordinator) {
return serverInfo.getHttpUri();
}
else {
return server.getBaseUrl();
}
}

public URI resolve(String path)
{
return server.getBaseUrl().resolve(path);
if (coordinator) {
return serverInfo.getHttpUri().resolve(path);
}
else {
return server.getBaseUrl().resolve(path);
}
}

public HostAndPort getAddress()
Expand All @@ -597,7 +626,15 @@ public HostAndPort getAddress()

public HostAndPort getHttpsAddress()
{
URI httpsUri = server.getHttpServerInfo().getHttpsUri();

URI httpsUri;
if (coordinator) {
httpsUri = serverInfo.getHttpsUri();
}
else {
httpsUri = server.getHttpServerInfo().getHttpsUri();
}

return HostAndPort.fromParts(httpsUri.getHost(), httpsUri.getPort());
}

Expand Down
Loading