Skip to content

Commit c677a20

Browse files
committed
Moved execution of servlet to the filter chain to allow filters to halt execution. The servlet is injected in the chain by the FilterChainManager. This should fix the last issues with #65 and fix #66
1 parent 0433b6b commit c677a20

File tree

12 files changed

+279
-42
lines changed

12 files changed

+279
-42
lines changed

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.slf4j.Logger;
2323
import org.slf4j.LoggerFactory;
2424

25+
import javax.servlet.FilterChain;
2526
import javax.servlet.Servlet;
2627
import javax.servlet.ServletContext;
2728
import javax.servlet.ServletException;
@@ -168,6 +169,10 @@ protected void setServletContext(final ServletContext context) {
168169
filterChainManager = new AwsFilterChainManager((AwsServletContext)context);
169170
}
170171

172+
protected FilterChain getFilterChain(ContainerRequestType req, Servlet servlet) {
173+
return filterChainManager.getFilterChain(req, servlet);
174+
}
175+
171176

172177
//-------------------------------------------------------------
173178
// Methods - Protected
@@ -182,12 +187,11 @@ protected void setServletContext(final ServletContext context) {
182187
* @throws ServletException
183188
*/
184189
protected void doFilter(ContainerRequestType request, ContainerResponseType response, Servlet servlet) throws IOException, ServletException {
185-
FilterChainHolder chain = filterChainManager.getFilterChain(request, servlet);
190+
FilterChain chain = getFilterChain(request, servlet);
186191
log.debug("FilterChainHolder.doFilter {}", chain);
187192
chain.doFilter(request, response);
188193
}
189194

190-
191195
//-------------------------------------------------------------
192196
// Inner Class -
193197
//-------------------------------------------------------------

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
* during a request lifecycle
2929
*/
3030
public class FilterChainHolder implements FilterChain {
31-
private final Servlet servlet;
3231

3332
//-------------------------------------------------------------
3433
// Variables - Private
@@ -46,22 +45,19 @@ public class FilterChainHolder implements FilterChain {
4645

4746
/**
4847
* Creates a new empty <code>FilterChainHolder</code>
49-
* @param servlet
5048
*/
51-
FilterChainHolder(Servlet servlet) {
52-
this(new ArrayList<>(), servlet);
49+
FilterChainHolder() {
50+
this(new ArrayList<>());
5351
}
5452

5553

5654
/**
5755
* Creates a new instance of a filter chain holder
5856
* @param allFilters A populated list of <code>FilterHolder</code> objects
59-
* @param servlet
6057
*/
61-
FilterChainHolder(List<FilterHolder> allFilters, Servlet servlet) {
58+
FilterChainHolder(List<FilterHolder> allFilters) {
6259
filters = allFilters;
6360
resetHolder();
64-
this.servlet = servlet;
6561
}
6662

6763

@@ -72,34 +68,27 @@ public class FilterChainHolder implements FilterChain {
7268
@Override
7369
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException, ServletException {
7470
currentFilter++;
75-
if (filters == null || filters.size() == 0 ) {
76-
log.debug("Could not find filters to execute, returning");
77-
return;
78-
} else if (currentFilter > filters.size() - 1) {
79-
if (null != servlet) {
80-
log.debug("Starting servlet {}", servlet);
81-
servlet.service(servletRequest, servletResponse);
82-
log.debug("Executed servlet {}", servlet);
83-
return;
84-
} else {
85-
log.debug("No more filters");
86-
return;
87-
}
88-
}
8971
// TODO: We do not check for async filters here
9072

91-
FilterHolder holder = filters.get(currentFilter);
73+
// if we still have filters, keep running through the chain
74+
if (currentFilter <= filters.size() - 1) {
75+
FilterHolder holder = filters.get(currentFilter);
76+
77+
// lazily initialize filters when they are needed
78+
if (!holder.isFilterInitialized()) {
79+
holder.init();
80+
}
81+
log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
82+
currentFilter, holder.getFilterName(), holder.getFilter());
83+
holder.getFilter().doFilter(servletRequest, servletResponse, this);
84+
log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
85+
currentFilter, holder.getFilterName(), holder.getFilter());
86+
}
9287

93-
// lazily initialize filters when they are needed
94-
if (!holder.isFilterInitialized()) {
95-
holder.init();
88+
// if for some reason the response wasn't flushed yet, we force it here.
89+
if (!servletResponse.isCommitted()) {
90+
servletResponse.flushBuffer();
9691
}
97-
log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
98-
currentFilter, holder.getFilterName(), holder.getFilter());
99-
holder.getFilter().doFilter(servletRequest, servletResponse, this);
100-
log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
101-
currentFilter, holder.getFilterName(), holder.getFilter());
102-
currentFilter--;
10392
}
10493

10594

@@ -162,6 +151,6 @@ private void resetHolder() {
162151

163152
@Override
164153
public String toString() {
165-
return "filters=" + filters + ", servlet=" + servlet;
154+
return "filters=" + filters;
166155
}
167156
}

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@
1313
package com.amazonaws.serverless.proxy.internal.servlet;
1414

1515
import javax.servlet.DispatcherType;
16+
import javax.servlet.Filter;
17+
import javax.servlet.FilterChain;
18+
import javax.servlet.FilterConfig;
1619
import javax.servlet.Servlet;
1720
import javax.servlet.ServletContext;
21+
import javax.servlet.ServletException;
22+
import javax.servlet.ServletRequest;
23+
import javax.servlet.ServletResponse;
1824
import javax.servlet.http.HttpServletRequest;
1925

26+
import java.io.IOException;
2027
import java.util.Collections;
2128
import java.util.HashMap;
2229
import java.util.List;
@@ -94,10 +101,13 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
94101
return getFilterChainCache(type, targetPath, servlet);
95102
}
96103

97-
FilterChainHolder chainHolder = new FilterChainHolder(servlet);
104+
FilterChainHolder chainHolder = new FilterChainHolder();
98105

99106
Map<String, FilterHolder> registrations = getFilterHolders();
100107
if (registrations == null || registrations.size() == 0) {
108+
if (servlet != null) {
109+
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext));
110+
}
101111
return chainHolder;
102112
}
103113
for (String name : registrations.keySet()) {
@@ -117,6 +127,10 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
117127
// we assume we only ever have one servlet.
118128
}
119129

130+
if (servlet != null) {
131+
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext));
132+
}
133+
120134
putFilterChainCache(type, targetPath, chainHolder);
121135
// update total filter size
122136
if (filtersSize != registrations.size()) {
@@ -148,7 +162,7 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S
148162
return null;
149163
}
150164

151-
return new FilterChainHolder(filterCache.get(key), servlet);
165+
return new FilterChainHolder(filterCache.get(key));
152166
}
153167

154168

@@ -303,4 +317,34 @@ void setDispatcherType(DispatcherType dispatcherType) {
303317
this.dispatcherType = dispatcherType;
304318
}
305319
}
320+
321+
private class ServletExecutionFilter implements Filter {
322+
323+
private FilterConfig config;
324+
private Servlet handlerServlet;
325+
326+
public ServletExecutionFilter(Servlet handler) {
327+
handlerServlet = handler;
328+
}
329+
330+
@Override
331+
public void init(FilterConfig filterConfig)
332+
throws ServletException {
333+
config = filterConfig;
334+
}
335+
336+
337+
@Override
338+
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
339+
throws IOException, ServletException {
340+
handlerServlet.service(servletRequest, servletResponse);
341+
filterChain.doFilter(servletRequest, servletResponse);
342+
}
343+
344+
345+
@Override
346+
public void destroy() {
347+
348+
}
349+
}
306350
}

aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@
2929
import spark.embeddedserver.EmbeddedServerFactory;
3030
import spark.embeddedserver.EmbeddedServers;
3131

32+
import javax.servlet.DispatcherType;
33+
import javax.servlet.FilterRegistration;
34+
3235
import java.lang.reflect.Field;
3336
import java.lang.reflect.InvocationTargetException;
3437
import java.lang.reflect.Method;
38+
import java.util.EnumSet;
3539
import java.util.concurrent.CountDownLatch;
3640

3741
/**
@@ -162,10 +166,12 @@ protected void handleRequest(AwsProxyHttpServletRequest httpServletRequest, AwsH
162166
if (startupHandler != null) {
163167
startupHandler.onStartup(getServletContext());
164168
}
169+
170+
// manually add the spark filter to the chain. This should the last one and match all uris
171+
FilterRegistration.Dynamic sparkRegistration = getServletContext().addFilter("SparkFilter", embeddedServer.getSparkFilter());
172+
sparkRegistration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/*");
165173
}
166174

167175
doFilter(httpServletRequest, httpServletResponse, null);
168-
169-
embeddedServer.handle(httpServletRequest, httpServletResponse);
170176
}
171177
}

aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/embeddedserver/LambdaEmbeddedServer.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import spark.ssl.SslStores;
1010
import spark.staticfiles.StaticFilesConfiguration;
1111

12+
import javax.servlet.Filter;
1213
import javax.servlet.ServletException;
1314
import javax.servlet.http.HttpServletRequest;
1415
import javax.servlet.http.HttpServletResponse;
@@ -95,4 +96,13 @@ public void handle(HttpServletRequest request, HttpServletResponse response)
9596
throws IOException, ServletException {
9697
sparkFilter.doFilter(request, response, null);
9798
}
99+
100+
101+
/**
102+
* Returns the initialized instance of the main Spark filter.
103+
* @return The spark filter instance.
104+
*/
105+
public Filter getSparkFilter() {
106+
return sparkFilter;
107+
}
98108
}

aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandlerTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
88
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
99
import com.amazonaws.serverless.proxy.spark.filter.CustomHeaderFilter;
10+
import com.amazonaws.serverless.proxy.spark.filter.UnauthenticatedFilter;
1011

1112
import org.junit.AfterClass;
1213
import org.junit.BeforeClass;
@@ -60,6 +61,50 @@ public void filters_onStartupMethod_executeFilters() {
6061

6162
}
6263

64+
@Test
65+
public void filters_unauthenticatedFilter_stopRequestProcessing() {
66+
67+
SparkLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = null;
68+
try {
69+
handler = SparkLambdaContainerHandler.getAwsProxyHandler();
70+
} catch (ContainerInitializationException e) {
71+
e.printStackTrace();
72+
fail();
73+
}
74+
75+
handler.onStartup(c -> {
76+
if (c == null) {
77+
System.out.println("Null servlet context");
78+
fail();
79+
}
80+
FilterRegistration.Dynamic registration = c.addFilter("UnauthenticatedFilter", UnauthenticatedFilter.class);
81+
// update the registration to map to a path
82+
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/unauth");
83+
// servlet name mappings are disabled and will throw an exception
84+
});
85+
86+
configureRoutes();
87+
88+
// first we test without the custom header, we expect request processing to complete
89+
// successfully
90+
AwsProxyRequest req = new AwsProxyRequestBuilder().method("GET").path("/unauth").build();
91+
AwsProxyResponse response = handler.proxy(req, new MockLambdaContext());
92+
93+
assertNotNull(response);
94+
assertEquals(200, response.getStatusCode());
95+
assertEquals(RESPONSE_BODY_TEXT, response.getBody());
96+
97+
// now we test with the custom header, this should stop request processing in the
98+
// filter and return an unauthenticated response
99+
AwsProxyRequest unauthReq = new AwsProxyRequestBuilder().method("GET").path("/unauth")
100+
.header(UnauthenticatedFilter.HEADER_NAME, "1").build();
101+
AwsProxyResponse unauthResp = handler.proxy(unauthReq, new MockLambdaContext());
102+
103+
assertNotNull(unauthResp);
104+
assertEquals(UnauthenticatedFilter.RESPONSE_STATUS, unauthResp.getStatusCode());
105+
assertEquals("", unauthResp.getBody());
106+
}
107+
63108
@AfterClass
64109
public static void stopSpark() {
65110
Spark.stop();
@@ -70,5 +115,10 @@ private static void configureRoutes() {
70115
res.status(200);
71116
return RESPONSE_BODY_TEXT;
72117
});
118+
119+
get("/unauth", (req, res) -> {
120+
res.status(200);
121+
return RESPONSE_BODY_TEXT;
122+
});
73123
}
74124
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.amazonaws.serverless.proxy.spark.filter;
2+
3+
4+
import javax.servlet.Filter;
5+
import javax.servlet.FilterChain;
6+
import javax.servlet.FilterConfig;
7+
import javax.servlet.ServletException;
8+
import javax.servlet.ServletRequest;
9+
import javax.servlet.ServletResponse;
10+
import javax.servlet.http.HttpServletRequest;
11+
import javax.servlet.http.HttpServletResponse;
12+
13+
import java.io.IOException;
14+
15+
16+
public class UnauthenticatedFilter implements Filter {
17+
public static final String HEADER_NAME = "X-Unauthenticated-Response";
18+
public static final int RESPONSE_STATUS = 401;
19+
20+
@Override
21+
public void init(FilterConfig filterConfig)
22+
throws ServletException {
23+
24+
}
25+
26+
27+
@Override
28+
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
29+
throws IOException, ServletException {
30+
System.out.println("Running unauth filter");
31+
if (((HttpServletRequest)servletRequest).getHeader(HEADER_NAME) != null) {
32+
((HttpServletResponse) servletResponse).setStatus(401);
33+
System.out.println("Returning 401");
34+
return;
35+
}
36+
System.out.println("Continue chain");
37+
filterChain.doFilter(servletRequest, servletResponse);
38+
}
39+
40+
41+
@Override
42+
public void destroy() {
43+
44+
}
45+
}

0 commit comments

Comments
 (0)