diff --git a/src/main/java/io/dropwizard/web/conf/CorsFilterFactory.java b/src/main/java/io/dropwizard/web/conf/CorsFilterFactory.java index c301bf7..27e36bb 100644 --- a/src/main/java/io/dropwizard/web/conf/CorsFilterFactory.java +++ b/src/main/java/io/dropwizard/web/conf/CorsFilterFactory.java @@ -1,19 +1,18 @@ package io.dropwizard.web.conf; import com.fasterxml.jackson.annotation.JsonProperty; -import io.dropwizard.jetty.setup.ServletEnvironment; import io.dropwizard.core.setup.Environment; import io.dropwizard.util.Duration; -import java.util.Collections; -import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.FilterRegistration; -import org.eclipse.jetty.ee10.servlets.CrossOriginFilter; +import java.util.Set; +import org.eclipse.jetty.ee10.servlet.ServletContextHandler; +import org.eclipse.jetty.http.pathmap.PathSpec; +import org.eclipse.jetty.server.handler.CrossOriginHandler; +import org.eclipse.jetty.server.handler.PathMappingsHandler; public class CorsFilterFactory { @JsonProperty @@ -98,45 +97,46 @@ public void setChainPreflight(boolean chainPreflight) { } public void build(Environment environment, String urlPattern) { - // build map of init parameters - final Map builder = new HashMap<>(); + + final CrossOriginHandler corsHandler = new CrossOriginHandler(); if (allowedOrigins != null && !allowedOrigins.isEmpty()) { - builder.put(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, String.join(",", allowedOrigins)); + corsHandler.setAllowedOriginPatterns(Set.copyOf(allowedOrigins)); } if (allowedTimingOrigins != null && !allowedTimingOrigins.isEmpty()) { - builder.put(CrossOriginFilter.ALLOWED_TIMING_ORIGINS_PARAM, String.join(",", allowedTimingOrigins)); + corsHandler.setAllowedTimingOriginPatterns(Set.copyOf(allowedTimingOrigins)); } if (allowedMethods != null && !allowedMethods.isEmpty()) { - builder.put(CrossOriginFilter.ALLOWED_METHODS_PARAM, String.join(",", allowedMethods)); + corsHandler.setAllowedMethods(Set.copyOf(allowedMethods)); } if (allowedHeaders != null && !allowedHeaders.isEmpty()) { - builder.put(CrossOriginFilter.ALLOWED_HEADERS_PARAM, String.join(",", allowedHeaders)); + corsHandler.setAllowedHeaders(Set.copyOf(allowedHeaders)); } if (preflightMaxAge != null) { - builder.put(CrossOriginFilter.PREFLIGHT_MAX_AGE_PARAM, String.valueOf(preflightMaxAge.toSeconds())); + corsHandler.setPreflightMaxAge(preflightMaxAge.toJavaDuration()); } if (allowCredentials != null) { - builder.put(CrossOriginFilter.ALLOW_CREDENTIALS_PARAM, String.valueOf(allowCredentials)); + corsHandler.setAllowCredentials(allowCredentials); } if (exposedHeaders != null && !exposedHeaders.isEmpty()) { - builder.put(CrossOriginFilter.EXPOSED_HEADERS_PARAM, String.join(",", exposedHeaders)); + setExposedHeaders(exposedHeaders); } if (chainPreflight != null) { - builder.put(CrossOriginFilter.CHAIN_PREFLIGHT_PARAM, String.valueOf(chainPreflight)); + setChainPreflight(chainPreflight); } - // configure filter - final ServletEnvironment servlets = environment.servlets(); - final FilterRegistration.Dynamic cors = servlets.addFilter("cross-origin-filter", CrossOriginFilter.class); - cors.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, urlPattern); - cors.setInitParameters(Collections.unmodifiableMap(builder)); + final PathMappingsHandler pathHandler = new PathMappingsHandler(); + pathHandler.addMapping(PathSpec.from(urlPattern), corsHandler); + + final ServletContextHandler contextHandler = environment.getApplicationContext(); + contextHandler.insertHandler(corsHandler); + } } diff --git a/src/test/java/io/dropwizard/web/conf/CorsFilterFactoryTest.java b/src/test/java/io/dropwizard/web/conf/CorsFilterFactoryTest.java index 66c2c16..932b62f 100644 --- a/src/test/java/io/dropwizard/web/conf/CorsFilterFactoryTest.java +++ b/src/test/java/io/dropwizard/web/conf/CorsFilterFactoryTest.java @@ -1,10 +1,11 @@ package io.dropwizard.web.conf; import com.google.common.collect.ImmutableList; -import io.dropwizard.jetty.setup.ServletEnvironment; +import io.dropwizard.jetty.MutableServletContextHandler; import io.dropwizard.core.setup.Environment; -import org.eclipse.jetty.ee10.servlets.CrossOriginFilter; -import org.hamcrest.CoreMatchers; +import java.util.Set; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.handler.CrossOriginHandler; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -12,23 +13,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import java.util.Collections; -import java.util.EnumSet; -import java.util.Map; - -import jakarta.servlet.Filter; -import jakarta.servlet.FilterRegistration; - import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.aMapWithSize; -import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; - +import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -37,17 +25,9 @@ public class CorsFilterFactoryTest { @Mock Environment env; @Mock - ServletEnvironment servlets; - @Mock - FilterRegistration.Dynamic registration; - @Captor - ArgumentCaptor filterNameCaptor; + MutableServletContextHandler contextHandler; @Captor - ArgumentCaptor> filterClassCaptor; - @Captor - ArgumentCaptor urlPatternCaptor; - @Captor - ArgumentCaptor> initParamsCaptor; + ArgumentCaptor handlerCaptor; @BeforeEach public void setUp() throws Exception { @@ -55,12 +35,10 @@ public void setUp() throws Exception { } @Test - public void configureFilter() { + public void configureHandler() { // given - when(env.servlets()).thenReturn(servlets); - when(servlets.addFilter(anyString(), any(Class.class))).thenReturn(registration); - doNothing().when(registration).addMappingForUrlPatterns(any(EnumSet.class), anyBoolean(), anyString()); - when(registration.setInitParameters(anyMap())).thenReturn(Collections.emptySet()); + when(env.getApplicationContext()).thenReturn(contextHandler); + doNothing().when(contextHandler).insertHandler(any(Handler.Singleton.class)); String urlPattern = "/example/*"; CorsFilterFactory factory = new CorsFilterFactory(); factory.setAllowedOrigins(ImmutableList.of("example.com", "foo.com")); @@ -69,13 +47,13 @@ public void configureFilter() { factory.build(env, urlPattern); // then - verify(servlets).addFilter(filterNameCaptor.capture(), filterClassCaptor.capture()); - verify(registration).addMappingForUrlPatterns(any(), eq(true), urlPatternCaptor.capture()); - verify(registration).setInitParameters(initParamsCaptor.capture()); - assertThat(filterNameCaptor.getValue(), is("cross-origin-filter")); - assertThat(filterClassCaptor.getValue(), is(CoreMatchers.>equalTo(CrossOriginFilter.class))); - assertThat(urlPatternCaptor.getValue(), is(urlPattern)); - assertThat(initParamsCaptor.getValue(), aMapWithSize(1)); - assertThat(initParamsCaptor.getValue(), hasEntry(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, "example.com,foo.com")); + verify(contextHandler).insertHandler(handlerCaptor.capture()); + assertThat(handlerCaptor.getValue(), is(notNullValue())); + assertThat(handlerCaptor.getValue() instanceof CrossOriginHandler, is(true)); + assertThat(((CrossOriginHandler) handlerCaptor.getValue()) + .getAllowedOriginPatterns() + , is(Set.of("example.com", "foo.com"))); + + } }