diff --git a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java
index e29bd3eb62..ff9f9824de 100644
--- a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java
+++ b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java
@@ -29,9 +29,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.stream.Stream;
-@SuppressWarnings("checkstyle:LineLength")
/**
* A request filter that blocks malicious requests. Invalid request will respond with a 400 response code.
*
@@ -40,15 +38,19 @@
*
Semicolon - can be disabled by setting {@code blockSemicolon = false}
* Backslash - can be disabled by setting {@code blockBackslash = false}
* Non-ASCII characters - can be disabled by setting {@code blockNonAscii = false},
- * the ability to disable this check will be removed in future version.
+ * the ability to disable this check will be removed in future version.
* Path traversals - can be disabled by setting {@code blockTraversal = false}
*
*
- * @see
- * This class was inspired by Spring Security StrictHttpFirewall
+ * This class was inspired by Spring Security StrictHttpFirewall
* @since 1.6
*/
public class InvalidRequestFilter extends AccessControlFilter {
+ public enum PathTraversalBlockMode {
+ STRICT,
+ NORMAL,
+ NO_BLOCK;
+ }
private static final List SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));
@@ -64,35 +66,27 @@ public class InvalidRequestFilter extends AccessControlFilter {
private boolean blockNonAscii = true;
- private boolean blockTraversal = true;
-
- private boolean blockEncodedPeriod = true;
-
- private boolean blockEncodedForwardSlash = true;
-
- private boolean blockRewriteTraversal = true;
+ private PathTraversalBlockMode pathTraversalBlockMode = PathTraversalBlockMode.NORMAL;
@Override
protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception {
HttpServletRequest request = WebUtils.toHttp(req);
// check the original and decoded values
+
// user request string (not decoded)
return isValid(request.getRequestURI())
// decoded servlet part
&& isValid(request.getServletPath())
- // decoded path info (may be null)
+ // decoded path info (maybe null)
&& isValid(request.getPathInfo());
}
- @SuppressWarnings("checkstyle:BooleanExpressionComplexity")
private boolean isValid(String uri) {
return !StringUtils.hasText(uri)
- || (!containsSemicolon(uri)
- && !containsBackslash(uri)
- && !containsNonAsciiCharacters(uri)
- && !containsTraversal(uri)
- && !containsEncodedPeriods(uri)
- && !containsEncodedForwardSlash(uri));
+ || (!containsSemicolon(uri)
+ && !containsBackslash(uri)
+ && !containsNonAsciiCharacters(uri))
+ && !containsTraversal(uri);
}
@Override
@@ -134,23 +128,13 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
}
private boolean containsTraversal(String uri) {
- if (isBlockTraversal()) {
- return !isNormalized(uri)
- || (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains));
+ if (pathTraversalBlockMode == PathTraversalBlockMode.NORMAL) {
+ return !(isNormalized(uri));
}
- return false;
- }
-
- private boolean containsEncodedPeriods(String uri) {
- if (isBlockEncodedPeriod()) {
- return PERIOD.stream().anyMatch(uri::contains);
- }
- return false;
- }
-
- private boolean containsEncodedForwardSlash(String uri) {
- if (isBlockEncodedForwardSlash()) {
- return FORWARDSLASH.stream().anyMatch(uri::contains);
+ if (pathTraversalBlockMode == PathTraversalBlockMode.STRICT) {
+ return !(isNormalized(uri)
+ && PERIOD.stream().noneMatch(uri::contains)
+ && FORWARDSLASH.stream().noneMatch(uri::contains));
}
return false;
}
@@ -205,35 +189,52 @@ public void setBlockNonAscii(boolean blockNonAscii) {
this.blockNonAscii = blockNonAscii;
}
- public boolean isBlockTraversal() {
- return blockTraversal;
+ public PathTraversalBlockMode getPathTraversalBlockMode() {
+ return pathTraversalBlockMode;
}
- public void setBlockTraversal(boolean blockTraversal) {
- this.blockTraversal = blockTraversal;
+ public void setBlockPathTraversal(PathTraversalBlockMode mode) {
+ this.pathTraversalBlockMode = mode;
}
public boolean isBlockEncodedPeriod() {
- return blockEncodedPeriod;
+ return pathTraversalBlockMode == PathTraversalBlockMode.STRICT;
}
public void setBlockEncodedPeriod(boolean blockEncodedPeriod) {
- this.blockEncodedPeriod = blockEncodedPeriod;
+ setBlockPathTraversal(blockEncodedPeriod ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL);
}
public boolean isBlockEncodedForwardSlash() {
- return blockEncodedForwardSlash;
+ return pathTraversalBlockMode == PathTraversalBlockMode.STRICT;
}
public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) {
- this.blockEncodedForwardSlash = blockEncodedForwardSlash;
+ setBlockPathTraversal(blockEncodedForwardSlash ? PathTraversalBlockMode.STRICT : PathTraversalBlockMode.NORMAL);
}
public boolean isBlockRewriteTraversal() {
- return blockRewriteTraversal;
+ return pathTraversalBlockMode == PathTraversalBlockMode.NORMAL;
}
public void setBlockRewriteTraversal(boolean blockRewriteTraversal) {
- this.blockRewriteTraversal = blockRewriteTraversal;
+ setBlockPathTraversal(blockRewriteTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK);
+ }
+
+ /**
+ * @deprecated use {@link #getPathTraversalBlockMode()} instead
+ */
+ @Deprecated
+ public boolean isBlockTraversal() {
+ return pathTraversalBlockMode != PathTraversalBlockMode.NO_BLOCK;
+ }
+
+ /**
+ *
+ * @deprecated Use {@link #setBlockPathTraversal(PathTraversalBlockMode)}
+ */
+ @Deprecated
+ public void setBlockTraversal(boolean blockTraversal) {
+ this.pathTraversalBlockMode = blockTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK;
}
}
diff --git a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy
index a046670d36..bb14e5395d 100644
--- a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy
+++ b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy
@@ -39,10 +39,8 @@ class InvalidRequestFilterTest {
assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash()
assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii()
assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon()
- assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal()
- assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal()
- assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod()
- assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash()
+ assertThat "filter.blockTraversal expected to be NORMAL",
+ filter.getPathTraversalBlockMode() == InvalidRequestFilter.PathTraversalBlockMode.NORMAL
}
@Test
@@ -63,6 +61,7 @@ class InvalidRequestFilterTest {
}
}
+
@Test
void testFilterBlocks() {
InvalidRequestFilter filter = new InvalidRequestFilter()
@@ -76,11 +75,10 @@ class InvalidRequestFilterTest {
assertPathBlocked(filter, "/something", "/;something")
assertPathBlocked(filter, "/something", "/something", "/;")
- assertPathBlocked(filter, "/something", "/something", "/.;")
}
@Test
- void testBlocksTraversal() {
+ void testBlocksTraversalNormal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "/something/../")
assertPathBlocked(filter, "/something/../bar")
@@ -89,7 +87,6 @@ class InvalidRequestFilterTest {
assertPathBlocked(filter, "/..")
assertPathBlocked(filter, "..")
assertPathBlocked(filter, "../")
- assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/./")
assertPathBlocked(filter, "/something/./bar")
assertPathBlocked(filter, "/something/\u002e/bar")
@@ -97,69 +94,42 @@ class InvalidRequestFilterTest {
assertPathBlocked(filter, "/something/.")
assertPathBlocked(filter, "/.")
assertPathBlocked(filter, "/something/../something/.")
- assertPathBlocked(filter, "/something/../something/.")
- assertPathBlocked(filter, "/something/.;")
- assertPathBlocked(filter, "/something/%2e%3b")
-
- assertPathAllowed(filter, "/something/.bar")
- assertPathAllowed(filter, "/.something")
- assertPathAllowed(filter, ".something")
- }
- @Test
- void testBlocksEncodedPeriod() {
- InvalidRequestFilter filter = new InvalidRequestFilter()
- assertPathBlocked(filter, "/%2esomething")
- assertPathBlocked(filter, "%2esomething")
- assertPathBlocked(filter, "%2E./")
- assertPathBlocked(filter, "%2F./")
- assertPathBlocked(filter, "/something/%2e;")
- assertPathBlocked(filter, "/something/%2e%3b")
- assertPathBlocked(filter, "/something/%2e%2E/bar/")
- assertPathBlocked(filter, "/something/%2e/bar/")
- }
-
- @Test
- void testAllowsEncodedPeriod() {
- InvalidRequestFilter filter = new InvalidRequestFilter()
- filter.setBlockEncodedPeriod(false)
- assertPathAllowed(filter, "/%2esomething")
- assertPathAllowed(filter, "%2esomething")
assertPathAllowed(filter, "%2E./")
- assertPathAllowed(filter, "/something/%2e%2E/bar/")
- assertPathAllowed(filter, "/something/%2e/bar/")
- }
-
- @Test
- void testBlocksEncodedForwardSlash() {
- InvalidRequestFilter filter = new InvalidRequestFilter()
- assertPathBlocked(filter, "%2F./")
- assertPathBlocked(filter, "/something/%2f/bar/")
- }
-
- @Test
- void testAllowsEncodedForwardSlash() {
- InvalidRequestFilter filter = new InvalidRequestFilter()
- filter.setBlockEncodedForwardSlash(false)
assertPathAllowed(filter, "%2F./")
+ assertPathAllowed(filter, "/something/%2e/bar/")
assertPathAllowed(filter, "/something/%2f/bar/")
+ assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
+ assertPathAllowed(filter, "/something/%2e%2E/bar/")
+ assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}
@Test
- void testBlocksRewriteTraversal() {
+ void testBlocksTraversalStrict() {
InvalidRequestFilter filter = new InvalidRequestFilter()
- filter.setBlockSemicolon(false)
- assertPathBlocked(filter, "/something/..;jsessionid=foobar")
- assertPathBlocked(filter, "/something/.;jsessionid=foobar")
- }
+ filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.STRICT)
+ assertPathBlocked(filter, "/something/../")
+ assertPathBlocked(filter, "/something/../bar")
+ assertPathBlocked(filter, "/something/../bar/")
+ assertPathBlocked(filter, "/something/..")
+ assertPathBlocked(filter, "/..")
+ assertPathBlocked(filter, "..")
+ assertPathBlocked(filter, "../")
+ assertPathBlocked(filter, "/something/./")
+ assertPathBlocked(filter, "/something/./bar")
+ assertPathBlocked(filter, "/something/\u002e/bar")
+ assertPathBlocked(filter, "/something/./bar/")
+ assertPathBlocked(filter, "/something/.")
+ assertPathBlocked(filter, "/.")
+ assertPathBlocked(filter, "/something/../something/.")
- @Test
- void testAllowRewriteTraversal() {
- InvalidRequestFilter filter = new InvalidRequestFilter()
- filter.setBlockSemicolon(false)
- filter.setBlockRewriteTraversal(false)
- assertPathAllowed(filter, "/something/..;jsessionid=foobar")
- assertPathAllowed(filter, "/something/.;jsessionid=foobar")
+ assertPathBlocked(filter, "%2E./")
+ assertPathBlocked(filter, "%2F./")
+ assertPathBlocked(filter, "/something/%2e/bar/")
+ assertPathBlocked(filter, "/something/%2f/bar/")
+ assertPathBlocked(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
+ assertPathBlocked(filter, "/something/%2e%2E/bar/")
+ assertPathBlocked(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}
@Test
@@ -213,7 +183,7 @@ class InvalidRequestFilterTest {
@Test
void testAllowTraversal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
- filter.setBlockTraversal(false)
+ filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.NO_BLOCK);
assertPathAllowed(filter, "/something/../")
assertPathAllowed(filter, "/something/../bar")
@@ -230,6 +200,14 @@ class InvalidRequestFilterTest {
assertPathAllowed(filter, "/something/.")
assertPathAllowed(filter, "/.")
assertPathAllowed(filter, "/something/../something/.")
+
+ assertPathAllowed(filter, "%2E./")
+ assertPathAllowed(filter, "%2F./")
+ assertPathAllowed(filter, "/something/%2e/bar/")
+ assertPathAllowed(filter, "/something/%2f/bar/")
+ assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/")
+ assertPathAllowed(filter, "/something/%2e%2E/bar/")
+ assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/")
}
static void assertPathBlocked(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {