diff --git a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java index e29bd3eb62..ff9f9824de 100644 --- a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java +++ b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java @@ -29,9 +29,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.stream.Stream; -@SuppressWarnings("checkstyle:LineLength") /** * A request filter that blocks malicious requests. Invalid request will respond with a 400 response code. *

@@ -40,15 +38,19 @@ *

  • Semicolon - can be disabled by setting {@code blockSemicolon = false}
  • *
  • Backslash - can be disabled by setting {@code blockBackslash = false}
  • *
  • Non-ASCII characters - can be disabled by setting {@code blockNonAscii = false}, - * the ability to disable this check will be removed in future version.
  • + * the ability to disable this check will be removed in future version. *
  • Path traversals - can be disabled by setting {@code blockTraversal = false}
  • * * - * @see - * This class was inspired by Spring Security StrictHttpFirewall + * This class was inspired by Spring Security StrictHttpFirewall * @since 1.6 */ public class InvalidRequestFilter extends AccessControlFilter { + public enum PathTraversalBlockMode { + STRICT, + NORMAL, + NO_BLOCK; + } private static final List SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B")); @@ -64,35 +66,27 @@ public class InvalidRequestFilter extends AccessControlFilter { private boolean blockNonAscii = true; - private boolean blockTraversal = true; - - private boolean blockEncodedPeriod = true; - - private boolean blockEncodedForwardSlash = true; - - private boolean blockRewriteTraversal = true; + private PathTraversalBlockMode pathTraversalBlockMode = PathTraversalBlockMode.NORMAL; @Override protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception { HttpServletRequest request = WebUtils.toHttp(req); // check the original and decoded values + // user request string (not decoded) return isValid(request.getRequestURI()) // decoded servlet part && isValid(request.getServletPath()) - // decoded path info (may be null) + // decoded path info (maybe null) && isValid(request.getPathInfo()); } - @SuppressWarnings("checkstyle:BooleanExpressionComplexity") private boolean isValid(String uri) { return !StringUtils.hasText(uri) - || (!containsSemicolon(uri) - && !containsBackslash(uri) - && !containsNonAsciiCharacters(uri) - && !containsTraversal(uri) - && !containsEncodedPeriods(uri) - && !containsEncodedForwardSlash(uri)); + || (!containsSemicolon(uri) + && !containsBackslash(uri) + && !containsNonAsciiCharacters(uri)) + && !containsTraversal(uri); } @Override @@ -134,23 +128,13 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) { } private boolean containsTraversal(String uri) { - if (isBlockTraversal()) { - return !isNormalized(uri) - || (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains)); + if (pathTraversalBlockMode == PathTraversalBlockMode.NORMAL) { + return !(isNormalized(uri)); } - return false; - } - - private boolean containsEncodedPeriods(String uri) { - if (isBlockEncodedPeriod()) { - return PERIOD.stream().anyMatch(uri::contains); - } - return false; - } - - private boolean containsEncodedForwardSlash(String uri) { - if (isBlockEncodedForwardSlash()) { - return FORWARDSLASH.stream().anyMatch(uri::contains); + if (pathTraversalBlockMode == PathTraversalBlockMode.STRICT) { + return !(isNormalized(uri) + && PERIOD.stream().noneMatch(uri::contains) + && FORWARDSLASH.stream().noneMatch(uri::contains)); } return false; } @@ -205,35 +189,52 @@ public void setBlockNonAscii(boolean blockNonAscii) { this.blockNonAscii = blockNonAscii; } - public boolean isBlockTraversal() { - return blockTraversal; + public PathTraversalBlockMode getPathTraversalBlockMode() { + return pathTraversalBlockMode; } - public void setBlockTraversal(boolean blockTraversal) { - this.blockTraversal = blockTraversal; + public void setBlockPathTraversal(PathTraversalBlockMode mode) { + this.pathTraversalBlockMode = mode; } public boolean isBlockEncodedPeriod() { - return blockEncodedPeriod; + return pathTraversalBlockMode == PathTraversalBlockMode.STRICT; } public void setBlockEncodedPeriod(boolean blockEncodedPeriod) { - this.blockEncodedPeriod = blockEncodedPeriod; + setBlockPathTraversal(blockEncodedPeriod ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL); } public boolean isBlockEncodedForwardSlash() { - return blockEncodedForwardSlash; + return pathTraversalBlockMode == PathTraversalBlockMode.STRICT; } public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) { - this.blockEncodedForwardSlash = blockEncodedForwardSlash; + setBlockPathTraversal(blockEncodedForwardSlash ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL); } public boolean isBlockRewriteTraversal() { - return blockRewriteTraversal; + return pathTraversalBlockMode == PathTraversalBlockMode.NORMAL; } public void setBlockRewriteTraversal(boolean blockRewriteTraversal) { - this.blockRewriteTraversal = blockRewriteTraversal; + setBlockPathTraversal(blockRewriteTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK); + } + + /** + * @deprecated use {@link #getPathTraversalBlockMode()} instead + */ + @Deprecated + public boolean isBlockTraversal() { + return pathTraversalBlockMode != PathTraversalBlockMode.NO_BLOCK; + } + + /** + * + * @deprecated Use {@link #setBlockPathTraversal(PathTraversalBlockMode)} + */ + @Deprecated + public void setBlockTraversal(boolean blockTraversal) { + this.pathTraversalBlockMode = blockTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK; } } diff --git a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy index a046670d36..bb14e5395d 100644 --- a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy +++ b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy @@ -39,10 +39,8 @@ class InvalidRequestFilterTest { assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash() assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii() assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon() - assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal() - assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal() - assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod() - assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash() + assertThat "filter.blockTraversal expected to be NORMAL", + filter.getPathTraversalBlockMode() == InvalidRequestFilter.PathTraversalBlockMode.NORMAL } @Test @@ -63,6 +61,7 @@ class InvalidRequestFilterTest { } } + @Test void testFilterBlocks() { InvalidRequestFilter filter = new InvalidRequestFilter() @@ -76,11 +75,10 @@ class InvalidRequestFilterTest { assertPathBlocked(filter, "/something", "/;something") assertPathBlocked(filter, "/something", "/something", "/;") - assertPathBlocked(filter, "/something", "/something", "/.;") } @Test - void testBlocksTraversal() { + void testBlocksTraversalNormal() { InvalidRequestFilter filter = new InvalidRequestFilter() assertPathBlocked(filter, "/something/../") assertPathBlocked(filter, "/something/../bar") @@ -89,7 +87,6 @@ class InvalidRequestFilterTest { assertPathBlocked(filter, "/..") assertPathBlocked(filter, "..") assertPathBlocked(filter, "../") - assertPathBlocked(filter, "%2F./") assertPathBlocked(filter, "/something/./") assertPathBlocked(filter, "/something/./bar") assertPathBlocked(filter, "/something/\u002e/bar") @@ -97,69 +94,42 @@ class InvalidRequestFilterTest { assertPathBlocked(filter, "/something/.") assertPathBlocked(filter, "/.") assertPathBlocked(filter, "/something/../something/.") - assertPathBlocked(filter, "/something/../something/.") - assertPathBlocked(filter, "/something/.;") - assertPathBlocked(filter, "/something/%2e%3b") - - assertPathAllowed(filter, "/something/.bar") - assertPathAllowed(filter, "/.something") - assertPathAllowed(filter, ".something") - } - @Test - void testBlocksEncodedPeriod() { - InvalidRequestFilter filter = new InvalidRequestFilter() - assertPathBlocked(filter, "/%2esomething") - assertPathBlocked(filter, "%2esomething") - assertPathBlocked(filter, "%2E./") - assertPathBlocked(filter, "%2F./") - assertPathBlocked(filter, "/something/%2e;") - assertPathBlocked(filter, "/something/%2e%3b") - assertPathBlocked(filter, "/something/%2e%2E/bar/") - assertPathBlocked(filter, "/something/%2e/bar/") - } - - @Test - void testAllowsEncodedPeriod() { - InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockEncodedPeriod(false) - assertPathAllowed(filter, "/%2esomething") - assertPathAllowed(filter, "%2esomething") assertPathAllowed(filter, "%2E./") - assertPathAllowed(filter, "/something/%2e%2E/bar/") - assertPathAllowed(filter, "/something/%2e/bar/") - } - - @Test - void testBlocksEncodedForwardSlash() { - InvalidRequestFilter filter = new InvalidRequestFilter() - assertPathBlocked(filter, "%2F./") - assertPathBlocked(filter, "/something/%2f/bar/") - } - - @Test - void testAllowsEncodedForwardSlash() { - InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockEncodedForwardSlash(false) assertPathAllowed(filter, "%2F./") + assertPathAllowed(filter, "/something/%2e/bar/") assertPathAllowed(filter, "/something/%2f/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/%2e%2E/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") } @Test - void testBlocksRewriteTraversal() { + void testBlocksTraversalStrict() { InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockSemicolon(false) - assertPathBlocked(filter, "/something/..;jsessionid=foobar") - assertPathBlocked(filter, "/something/.;jsessionid=foobar") - } + filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.STRICT) + assertPathBlocked(filter, "/something/../") + assertPathBlocked(filter, "/something/../bar") + assertPathBlocked(filter, "/something/../bar/") + assertPathBlocked(filter, "/something/..") + assertPathBlocked(filter, "/..") + assertPathBlocked(filter, "..") + assertPathBlocked(filter, "../") + assertPathBlocked(filter, "/something/./") + assertPathBlocked(filter, "/something/./bar") + assertPathBlocked(filter, "/something/\u002e/bar") + assertPathBlocked(filter, "/something/./bar/") + assertPathBlocked(filter, "/something/.") + assertPathBlocked(filter, "/.") + assertPathBlocked(filter, "/something/../something/.") - @Test - void testAllowRewriteTraversal() { - InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockSemicolon(false) - filter.setBlockRewriteTraversal(false) - assertPathAllowed(filter, "/something/..;jsessionid=foobar") - assertPathAllowed(filter, "/something/.;jsessionid=foobar") + assertPathBlocked(filter, "%2E./") + assertPathBlocked(filter, "%2F./") + assertPathBlocked(filter, "/something/%2e/bar/") + assertPathBlocked(filter, "/something/%2f/bar/") + assertPathBlocked(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathBlocked(filter, "/something/%2e%2E/bar/") + assertPathBlocked(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") } @Test @@ -213,7 +183,7 @@ class InvalidRequestFilterTest { @Test void testAllowTraversal() { InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockTraversal(false) + filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.NO_BLOCK); assertPathAllowed(filter, "/something/../") assertPathAllowed(filter, "/something/../bar") @@ -230,6 +200,14 @@ class InvalidRequestFilterTest { assertPathAllowed(filter, "/something/.") assertPathAllowed(filter, "/.") assertPathAllowed(filter, "/something/../something/.") + + assertPathAllowed(filter, "%2E./") + assertPathAllowed(filter, "%2F./") + assertPathAllowed(filter, "/something/%2e/bar/") + assertPathAllowed(filter, "/something/%2f/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/%2e%2E/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") } static void assertPathBlocked(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {