Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

@SuppressWarnings("checkstyle:LineLength")
/**
* A request filter that blocks malicious requests. Invalid request will respond with a 400 response code.
* <p>
Expand All @@ -40,15 +38,19 @@
* <li>Semicolon - can be disabled by setting {@code blockSemicolon = false}</li>
* <li>Backslash - can be disabled by setting {@code blockBackslash = false}</li>
* <li>Non-ASCII characters - can be disabled by setting {@code blockNonAscii = false},
* the ability to disable this check will be removed in future version.</li>
* the ability to disable this check will be removed in future version.</li>
* <li>Path traversals - can be disabled by setting {@code blockTraversal = false}</li>
* </ul>
*
* @see <a href="https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/firewall/StrictHttpFirewall.html">
* This class was inspired by Spring Security StrictHttpFirewall</a>
* This class was inspired by Spring Security StrictHttpFirewall
* @since 1.6
*/
public class InvalidRequestFilter extends AccessControlFilter {
public enum PathTraversalBlockMode {
STRICT,
NORMAL,
NO_BLOCK;
}

private static final List<String> SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));

Expand All @@ -64,35 +66,27 @@ public class InvalidRequestFilter extends AccessControlFilter {

private boolean blockNonAscii = true;

private boolean blockTraversal = true;

private boolean blockEncodedPeriod = true;

private boolean blockEncodedForwardSlash = true;

private boolean blockRewriteTraversal = true;
private PathTraversalBlockMode pathTraversalBlockMode = PathTraversalBlockMode.NORMAL;

@Override
protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception {
HttpServletRequest request = WebUtils.toHttp(req);
// check the original and decoded values

// user request string (not decoded)
return isValid(request.getRequestURI())
// decoded servlet part
&& isValid(request.getServletPath())
// decoded path info (may be null)
// decoded path info (maybe null)
&& isValid(request.getPathInfo());
}

@SuppressWarnings("checkstyle:BooleanExpressionComplexity")
private boolean isValid(String uri) {
return !StringUtils.hasText(uri)
|| (!containsSemicolon(uri)
&& !containsBackslash(uri)
&& !containsNonAsciiCharacters(uri)
&& !containsTraversal(uri)
&& !containsEncodedPeriods(uri)
&& !containsEncodedForwardSlash(uri));
|| (!containsSemicolon(uri)
&& !containsBackslash(uri)
&& !containsNonAsciiCharacters(uri))
&& !containsTraversal(uri);
}

@Override
Expand Down Expand Up @@ -134,23 +128,13 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
}

private boolean containsTraversal(String uri) {
if (isBlockTraversal()) {
return !isNormalized(uri)
|| (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains));
if (pathTraversalBlockMode == PathTraversalBlockMode.NORMAL) {
return !(isNormalized(uri));
}
return false;
}

private boolean containsEncodedPeriods(String uri) {
if (isBlockEncodedPeriod()) {
return PERIOD.stream().anyMatch(uri::contains);
}
return false;
}

private boolean containsEncodedForwardSlash(String uri) {
if (isBlockEncodedForwardSlash()) {
return FORWARDSLASH.stream().anyMatch(uri::contains);
if (pathTraversalBlockMode == PathTraversalBlockMode.STRICT) {
return !(isNormalized(uri)
&& PERIOD.stream().noneMatch(uri::contains)
&& FORWARDSLASH.stream().noneMatch(uri::contains));
}
return false;
}
Expand Down Expand Up @@ -205,35 +189,52 @@ public void setBlockNonAscii(boolean blockNonAscii) {
this.blockNonAscii = blockNonAscii;
}

public boolean isBlockTraversal() {
return blockTraversal;
public PathTraversalBlockMode getPathTraversalBlockMode() {
return pathTraversalBlockMode;
}

public void setBlockTraversal(boolean blockTraversal) {
this.blockTraversal = blockTraversal;
public void setBlockPathTraversal(PathTraversalBlockMode mode) {
this.pathTraversalBlockMode = mode;
}

public boolean isBlockEncodedPeriod() {
return blockEncodedPeriod;
return pathTraversalBlockMode == PathTraversalBlockMode.STRICT;
}

public void setBlockEncodedPeriod(boolean blockEncodedPeriod) {
this.blockEncodedPeriod = blockEncodedPeriod;
setBlockPathTraversal(blockEncodedPeriod ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL);
}

public boolean isBlockEncodedForwardSlash() {
return blockEncodedForwardSlash;
return pathTraversalBlockMode == PathTraversalBlockMode.STRICT;
}

public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) {
this.blockEncodedForwardSlash = blockEncodedForwardSlash;
setBlockPathTraversal(blockEncodedForwardSlash ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL);
}

public boolean isBlockRewriteTraversal() {
return blockRewriteTraversal;
return pathTraversalBlockMode == PathTraversalBlockMode.NORMAL;
}

public void setBlockRewriteTraversal(boolean blockRewriteTraversal) {
this.blockRewriteTraversal = blockRewriteTraversal;
setBlockPathTraversal(blockRewriteTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK);
}

/**
* @deprecated use {@link #getPathTraversalBlockMode()} instead
*/
@Deprecated
public boolean isBlockTraversal() {
return pathTraversalBlockMode != PathTraversalBlockMode.NO_BLOCK;
}

/**
*
* @deprecated Use {@link #setBlockPathTraversal(PathTraversalBlockMode)}
*/
@Deprecated
public void setBlockTraversal(boolean blockTraversal) {
this.pathTraversalBlockMode = blockTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ class InvalidRequestFilterTest {
assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash()
assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii()
assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon()
assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal()
assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal()
assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod()
assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash()
assertThat "filter.blockTraversal expected to be NORMAL",
filter.getPathTraversalBlockMode() == InvalidRequestFilter.PathTraversalBlockMode.NORMAL
}

@Test
Expand All @@ -63,6 +61,7 @@ class InvalidRequestFilterTest {
}
}


@Test
void testFilterBlocks() {
InvalidRequestFilter filter = new InvalidRequestFilter()
Expand All @@ -76,11 +75,10 @@ class InvalidRequestFilterTest {

assertPathBlocked(filter, "/something", "/;something")
assertPathBlocked(filter, "/something", "/something", "/;")
assertPathBlocked(filter, "/something", "/something", "/.;")
}

@Test
void testBlocksTraversal() {
void testBlocksTraversalNormal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "/something/../")
assertPathBlocked(filter, "/something/../bar")
Expand All @@ -89,77 +87,49 @@ class InvalidRequestFilterTest {
assertPathBlocked(filter, "/..")
assertPathBlocked(filter, "..")
assertPathBlocked(filter, "../")
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/./")
assertPathBlocked(filter, "/something/./bar")
assertPathBlocked(filter, "/something/\u002e/bar")
assertPathBlocked(filter, "/something/./bar/")
assertPathBlocked(filter, "/something/.")
assertPathBlocked(filter, "/.")
assertPathBlocked(filter, "/something/../something/.")
assertPathBlocked(filter, "/something/../something/.")
assertPathBlocked(filter, "/something/.;")
assertPathBlocked(filter, "/something/%2e%3b")

assertPathAllowed(filter, "/something/.bar")
assertPathAllowed(filter, "/.something")
assertPathAllowed(filter, ".something")
}

@Test
void testBlocksEncodedPeriod() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "/%2esomething")
assertPathBlocked(filter, "%2esomething")
assertPathBlocked(filter, "%2E./")
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/%2e;")
assertPathBlocked(filter, "/something/%2e%3b")
assertPathBlocked(filter, "/something/%2e%2E/bar/")
assertPathBlocked(filter, "/something/%2e/bar/")
}

@Test
void testAllowsEncodedPeriod() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockEncodedPeriod(false)
assertPathAllowed(filter, "/%2esomething")
assertPathAllowed(filter, "%2esomething")
assertPathAllowed(filter, "%2E./")
assertPathAllowed(filter, "/something/%2e%2E/bar/")
assertPathAllowed(filter, "/something/%2e/bar/")
}

@Test
void testBlocksEncodedForwardSlash() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/%2f/bar/")
}

@Test
void testAllowsEncodedForwardSlash() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockEncodedForwardSlash(false)
assertPathAllowed(filter, "%2F./")
assertPathAllowed(filter, "/something/%2e/bar/")
assertPathAllowed(filter, "/something/%2f/bar/")
assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
assertPathAllowed(filter, "/something/%2e%2E/bar/")
assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}

@Test
void testBlocksRewriteTraversal() {
void testBlocksTraversalStrict() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockSemicolon(false)
assertPathBlocked(filter, "/something/..;jsessionid=foobar")
assertPathBlocked(filter, "/something/.;jsessionid=foobar")
}
filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.STRICT)
assertPathBlocked(filter, "/something/../")
assertPathBlocked(filter, "/something/../bar")
assertPathBlocked(filter, "/something/../bar/")
assertPathBlocked(filter, "/something/..")
assertPathBlocked(filter, "/..")
assertPathBlocked(filter, "..")
assertPathBlocked(filter, "../")
assertPathBlocked(filter, "/something/./")
assertPathBlocked(filter, "/something/./bar")
assertPathBlocked(filter, "/something/\u002e/bar")
assertPathBlocked(filter, "/something/./bar/")
assertPathBlocked(filter, "/something/.")
assertPathBlocked(filter, "/.")
assertPathBlocked(filter, "/something/../something/.")

@Test
void testAllowRewriteTraversal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockSemicolon(false)
filter.setBlockRewriteTraversal(false)
assertPathAllowed(filter, "/something/..;jsessionid=foobar")
assertPathAllowed(filter, "/something/.;jsessionid=foobar")
assertPathBlocked(filter, "%2E./")
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/%2e/bar/")
assertPathBlocked(filter, "/something/%2f/bar/")
assertPathBlocked(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
assertPathBlocked(filter, "/something/%2e%2E/bar/")
assertPathBlocked(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}

@Test
Expand Down Expand Up @@ -213,7 +183,7 @@ class InvalidRequestFilterTest {
@Test
void testAllowTraversal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockTraversal(false)
filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.NO_BLOCK);

assertPathAllowed(filter, "/something/../")
assertPathAllowed(filter, "/something/../bar")
Expand All @@ -230,6 +200,14 @@ class InvalidRequestFilterTest {
assertPathAllowed(filter, "/something/.")
assertPathAllowed(filter, "/.")
assertPathAllowed(filter, "/something/../something/.")

assertPathAllowed(filter, "%2E./")
assertPathAllowed(filter, "%2F./")
assertPathAllowed(filter, "/something/%2e/bar/")
assertPathAllowed(filter, "/something/%2f/bar/")
assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
assertPathAllowed(filter, "/something/%2e%2E/bar/")
assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}

static void assertPathBlocked(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {
Expand Down