Skip to content

Commit ee67106

Browse files
committed
More unit tests, annotation support, bug fixes
Bug fixes and additional unit tests. Annotation support completes the feature required to address #15. The service does not automatically load annotated filter classes, these will have to be added manually with the new startup event.
1 parent d7bec0a commit ee67106

File tree

8 files changed

+414
-18
lines changed

8 files changed

+414
-18
lines changed

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,11 @@ public FilterRegistration.Dynamic addFilter(String name, Filter filter) {
354354
return null;
355355
}
356356

357-
FilterHolder newFilter = new FilterHolder(name, filter, this);
358-
filters.put(name, newFilter);
357+
FilterHolder newFilter = new FilterHolder(filter, this);
358+
if (!"".equals(name.trim())) {
359+
newFilter.setFilterName(name);
360+
}
361+
filters.put(newFilter.getFilterName(), newFilter);
359362
return newFilter.getRegistration();
360363
}
361364

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,31 @@ public void addFilter(FilterHolder newFilter) {
6262
filters.add(newFilter);
6363
}
6464

65+
/**
66+
* Returns the number of filters loaded in the chain holder
67+
* @return The number of filters in the chain holder. If the filter chain is null then this will return 0
68+
*/
69+
public int filterCount() {
70+
if (filters == null) {
71+
return 0;
72+
} else {
73+
return filters.size();
74+
}
75+
}
76+
77+
/**
78+
* Get the <code>FilterHolder</code> object from the chain at the given index.
79+
* @param idx The index in the chain. Use the <code>filterCount</code> method to get the filter count
80+
* @return A populated FilterHolder object
81+
*/
82+
public FilterHolder getFilter(int idx) {
83+
if (filters == null) {
84+
return null;
85+
} else {
86+
return filters.get(idx);
87+
}
88+
}
89+
6590
//-------------------------------------------------------------
6691
// Implementation - FilterChain
6792
//-------------------------------------------------------------

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public abstract class FilterChainManager<ServletContextType> {
3535
// we use the synchronizedMap because we do not expect high concurrency on this object. Lambda only allows one
3636
// event at a time per container
3737
private Map<TargetCacheKey, FilterChainHolder> filterCache = Collections.synchronizedMap(new HashMap<TargetCacheKey, FilterChainHolder>());
38+
private int filtersSize = -1;
3839

3940
//-------------------------------------------------------------
4041
// Variables - Protected
@@ -79,7 +80,11 @@ public FilterChainHolder getFilterChain(HttpServletRequest request) {
7980
String targetPath = request.getServletPath();
8081
DispatcherType type = request.getDispatcherType();
8182

82-
if (getFilterChainCache(type, targetPath) != null) {
83+
if (filtersSize == -1) {
84+
getFilterHolders().size();
85+
}
86+
// only return the cached result if the filter list hasn't changed in the meanwhile
87+
if (getFilterHolders().size() == filtersSize && getFilterChainCache(type, targetPath) != null) {
8388
return getFilterChainCache(type, targetPath);
8489
}
8590

@@ -107,6 +112,10 @@ public FilterChainHolder getFilterChain(HttpServletRequest request) {
107112
}
108113

109114
putFilterChainCache(type, targetPath, chainHolder);
115+
// update total filter size
116+
if (filtersSize != registrations.size()) {
117+
filtersSize = registrations.size();
118+
}
110119
return chainHolder;
111120
}
112121

@@ -188,13 +197,13 @@ protected boolean pathMatches(String target, String mapping) {
188197
return false;
189198
}
190199
// the exact work doesn't match the and holder is not a wildcard
191-
if (!targetParts[i].equals(mappingParts[i]) && !mappingParts[i].equals("*")) {
192-
return false;
193-
} else {
194-
// stop the loop, we have found a wildcard and all path parts prior to this matched.
195-
break;
200+
if (!targetParts[i].equals(mappingParts[i])) {
201+
if (mappingParts[i].equals("*")) {
202+
break;
203+
} else {
204+
return false;
205+
}
196206
}
197-
198207
}
199208

200209
return true;
@@ -204,7 +213,7 @@ protected boolean pathMatches(String target, String mapping) {
204213
* Object used as a key for the filter chain cache. It contains a target path and dispatcher type property. It overrides
205214
* the default <code>hashCode</code> and <code>equals</code> methods to return a consistent hash for comparison.
206215
*/
207-
private class TargetCacheKey {
216+
protected static class TargetCacheKey {
208217
private String targetPath;
209218
private DispatcherType dispatcherType;
210219

@@ -224,12 +233,34 @@ public void setDispatcherType(DispatcherType dispatcherType) {
224233
this.dispatcherType = dispatcherType;
225234
}
226235

236+
/**
237+
* The hash code for a cache key is calculated using the target path and dispatcher type. First, the target path
238+
* is cleaned following these rules:
239+
* 1. trim white spaces
240+
* 2. Add "/" as first character if not there
241+
* 3. Remove "/" as last character if it is there
242+
*
243+
* Once the path is cleaned, a string in the form of TARGET_PATH:DISPATCHER_TYPE is generated and used for the
244+
* hash code
245+
* @return An int representing the hash code fo the generated string
246+
*/
227247
@Override
228248
public int hashCode() {
229249
if (targetPath == null || dispatcherType == null) {
230250
return -1;
231251
}
232-
return (targetPath + ":" + dispatcherType.name()).hashCode();
252+
253+
// clean up path
254+
String hashString = targetPath.trim();
255+
if (hashString.endsWith(PATH_PART_SEPARATOR)) {
256+
hashString = hashString.substring(0, hashString.length() - 1);
257+
}
258+
if (!hashString.startsWith(PATH_PART_SEPARATOR)) {
259+
hashString = PATH_PART_SEPARATOR + hashString;
260+
}
261+
hashString += ":" + dispatcherType.name();
262+
263+
return hashString.hashCode();
233264
}
234265

235266
@Override

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
package com.amazonaws.serverless.proxy.internal.servlet;
1414

1515
import javax.servlet.*;
16+
import javax.servlet.annotation.WebFilter;
17+
import javax.servlet.annotation.WebInitParam;
1618
import java.util.*;
1719

1820
/**
@@ -55,10 +57,26 @@ public FilterHolder(String name, Filter newFilter, ServletContext context) {
5557
filterInitialized = false;
5658
}
5759

60+
public FilterHolder(Filter newFilter, ServletContext context) {
61+
this(newFilter.getClass().getName(), newFilter, context);
62+
63+
if (isAnnotated()) {
64+
filterName = readAnnotatedFilterName();
65+
initParameters = readAnnotatedInitParams();
66+
registration = new Registration(getAnnotation());
67+
}
68+
69+
}
70+
5871
//-------------------------------------------------------------
5972
// Methods - Public
6073
//-------------------------------------------------------------
6174

75+
76+
public void setFilterName(String filterName) {
77+
this.filterName = filterName;
78+
}
79+
6280
/**
6381
* Checks whether the filter this holder is responsible for has been initialized. This method should be checked before
6482
* calling a filter, if it returns false then you should call the <code>init</code> method.
@@ -128,6 +146,47 @@ public ServletContext getServletContext() {
128146
return servletContext;
129147
}
130148

149+
public boolean isAnnotated() {
150+
return filter.getClass().isAnnotationPresent(WebFilter.class);
151+
}
152+
153+
//-------------------------------------------------------------
154+
// Methods - Private
155+
//-------------------------------------------------------------
156+
157+
private String readAnnotatedFilterName() {
158+
if (isAnnotated()) {
159+
WebFilter regAnnotation = filter.getClass().getAnnotation(WebFilter.class);
160+
if (!"".equals(regAnnotation.filterName().trim())) {
161+
return regAnnotation.filterName();
162+
} else {
163+
return filter.getClass().getName();
164+
}
165+
} else {
166+
return null;
167+
}
168+
}
169+
170+
private Map<String, String> readAnnotatedInitParams() {
171+
Map<String, String> initParams = new HashMap<>();
172+
if (isAnnotated()) {
173+
WebFilter regAnnotation = filter.getClass().getAnnotation(WebFilter.class);
174+
for (WebInitParam param : regAnnotation.initParams()) {
175+
initParams.put(param.name(), param.value());
176+
}
177+
}
178+
179+
return initParams;
180+
}
181+
182+
private WebFilter getAnnotation() {
183+
if (isAnnotated()) {
184+
return filter.getClass().getAnnotation(WebFilter.class);
185+
} else {
186+
return null;
187+
}
188+
}
189+
131190
/**
132191
* Registration class for the filter. This object stores the servlet names and the url patterns the filter is
133192
* associated with.
@@ -143,6 +202,24 @@ public Registration() {
143202
asyncSupported = false;
144203
}
145204

205+
public Registration(WebFilter annotation) {
206+
urlPatterns = new ArrayList<>();
207+
dispatcherTypes = new ArrayList<>();
208+
209+
EnumSet<DispatcherType> dispatchers = EnumSet.noneOf(DispatcherType.class);
210+
dispatchers.addAll(Arrays.asList(annotation.dispatcherTypes()));
211+
212+
if (annotation.value().length > 0) {
213+
addMappingForUrlPatterns(dispatchers, true, annotation.value());
214+
}
215+
216+
if (annotation.urlPatterns().length > 0) {
217+
addMappingForUrlPatterns(dispatchers, true, annotation.urlPatterns());
218+
}
219+
220+
asyncSupported = annotation.asyncSupported();
221+
}
222+
146223
@Override
147224
public void addMappingForServletNames(EnumSet<DispatcherType> types, boolean isLast, String... servlets) {
148225
throw new UnsupportedOperationException();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package com.amazonaws.serverless.proxy.internal.servlet.filters;
14+
15+
import javax.servlet.*;
16+
import javax.servlet.annotation.WebFilter;
17+
import javax.servlet.http.HttpServletRequest;
18+
import javax.servlet.http.HttpServletResponse;
19+
import java.io.IOException;
20+
import java.util.regex.Pattern;
21+
22+
/**
23+
* Simple path validator filter. This is a default implementation to prevent malformed paths from hitting the framework
24+
* app. This applies to all paths by default
25+
*/
26+
@WebFilter(filterName = "UrlPathValidator", urlPatterns = {"/*"})
27+
public class UrlPathValidator implements Filter {
28+
public static final int DEFAULT_ERROR_CODE = 404;
29+
public static final Pattern PATH_PATTERN = Pattern.compile("^(/[-\\w:@&?=+,.!/~*'%$_;]*)?$");
30+
private int invalidStatusCode;
31+
32+
@Override
33+
public void init(FilterConfig filterConfig) throws ServletException {
34+
if (filterConfig.getInitParameter("invalid_status_code") != null) {
35+
String statusCode = filterConfig.getInitParameter("invalid_status_code");
36+
try {
37+
invalidStatusCode = Integer.parseInt(statusCode);
38+
} catch (NumberFormatException e) {
39+
invalidStatusCode = DEFAULT_ERROR_CODE;
40+
}
41+
}
42+
}
43+
44+
@Override
45+
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
46+
// the getServletPath method of the AwsProxyHttpServletRequest returns the request path
47+
String path = ((HttpServletRequest)servletRequest).getServletPath();
48+
if (path == null) {
49+
setErrorResponse(servletResponse);
50+
return;
51+
}
52+
53+
if (!PATH_PATTERN.matcher(path).matches()) {
54+
setErrorResponse(servletResponse);
55+
return;
56+
}
57+
58+
// Logic taken from the Apache UrlValidator. I opted not to include Apache lib as a dependency to save space
59+
// in the final Lambda function package
60+
int slashCount = countStrings("/", path);
61+
int dot2Count = countStrings("..", path);
62+
int slash2Count = countStrings("//", path);
63+
if (dot2Count > 0 && (slashCount - slash2Count - 1) <= dot2Count){
64+
setErrorResponse(servletResponse);
65+
return;
66+
}
67+
68+
filterChain.doFilter(servletRequest, servletResponse);
69+
}
70+
71+
@Override
72+
public void destroy() {
73+
74+
}
75+
76+
private void setErrorResponse(ServletResponse resp) {
77+
((HttpServletResponse)resp).setStatus(invalidStatusCode);
78+
}
79+
80+
private int countStrings(String needle, String haystack) {
81+
int curIndex = 0;
82+
int stringCount = 0;
83+
84+
while (curIndex != -1) {
85+
curIndex = haystack.indexOf(needle, curIndex);
86+
if (curIndex > -1) {
87+
curIndex++;
88+
stringCount++;
89+
}
90+
}
91+
return stringCount;
92+
}
93+
}

0 commit comments

Comments
 (0)