Skip to content

Commit 1822428

Browse files
wyke1jwyke
andauthored
Adapt CORS filter config to Dropwizard 5 (#656)
Co-authored-by: jwyke <jens.wyke@paypal.com>
1 parent c0fa6e4 commit 1822428

File tree

2 files changed

+39
-61
lines changed

2 files changed

+39
-61
lines changed

src/main/java/io/dropwizard/web/conf/CorsFilterFactory.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
package io.dropwizard.web.conf;
22

33
import com.fasterxml.jackson.annotation.JsonProperty;
4-
import io.dropwizard.jetty.setup.ServletEnvironment;
54
import io.dropwizard.core.setup.Environment;
65
import io.dropwizard.util.Duration;
76

8-
import java.util.Collections;
9-
import java.util.EnumSet;
107
import java.util.HashMap;
118
import java.util.List;
129
import java.util.Map;
1310

14-
import jakarta.servlet.DispatcherType;
15-
import jakarta.servlet.FilterRegistration;
16-
import org.eclipse.jetty.ee10.servlets.CrossOriginFilter;
11+
import java.util.Set;
12+
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
13+
import org.eclipse.jetty.http.pathmap.PathSpec;
14+
import org.eclipse.jetty.server.handler.CrossOriginHandler;
15+
import org.eclipse.jetty.server.handler.PathMappingsHandler;
1716

1817
public class CorsFilterFactory {
1918
@JsonProperty
@@ -98,45 +97,46 @@ public void setChainPreflight(boolean chainPreflight) {
9897
}
9998

10099
public void build(Environment environment, String urlPattern) {
101-
// build map of init parameters
102-
final Map<String, String> builder = new HashMap<>();
100+
101+
final CrossOriginHandler corsHandler = new CrossOriginHandler();
103102

104103
if (allowedOrigins != null && !allowedOrigins.isEmpty()) {
105-
builder.put(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, String.join(",", allowedOrigins));
104+
corsHandler.setAllowedOriginPatterns(Set.copyOf(allowedOrigins));
106105
}
107106

108107
if (allowedTimingOrigins != null && !allowedTimingOrigins.isEmpty()) {
109-
builder.put(CrossOriginFilter.ALLOWED_TIMING_ORIGINS_PARAM, String.join(",", allowedTimingOrigins));
108+
corsHandler.setAllowedTimingOriginPatterns(Set.copyOf(allowedTimingOrigins));
110109
}
111110

112111
if (allowedMethods != null && !allowedMethods.isEmpty()) {
113-
builder.put(CrossOriginFilter.ALLOWED_METHODS_PARAM, String.join(",", allowedMethods));
112+
corsHandler.setAllowedMethods(Set.copyOf(allowedMethods));
114113
}
115114

116115
if (allowedHeaders != null && !allowedHeaders.isEmpty()) {
117-
builder.put(CrossOriginFilter.ALLOWED_HEADERS_PARAM, String.join(",", allowedHeaders));
116+
corsHandler.setAllowedHeaders(Set.copyOf(allowedHeaders));
118117
}
119118

120119
if (preflightMaxAge != null) {
121-
builder.put(CrossOriginFilter.PREFLIGHT_MAX_AGE_PARAM, String.valueOf(preflightMaxAge.toSeconds()));
120+
corsHandler.setPreflightMaxAge(preflightMaxAge.toJavaDuration());
122121
}
123122

124123
if (allowCredentials != null) {
125-
builder.put(CrossOriginFilter.ALLOW_CREDENTIALS_PARAM, String.valueOf(allowCredentials));
124+
corsHandler.setAllowCredentials(allowCredentials);
126125
}
127126

128127
if (exposedHeaders != null && !exposedHeaders.isEmpty()) {
129-
builder.put(CrossOriginFilter.EXPOSED_HEADERS_PARAM, String.join(",", exposedHeaders));
128+
setExposedHeaders(exposedHeaders);
130129
}
131130

132131
if (chainPreflight != null) {
133-
builder.put(CrossOriginFilter.CHAIN_PREFLIGHT_PARAM, String.valueOf(chainPreflight));
132+
setChainPreflight(chainPreflight);
134133
}
135134

136-
// configure filter
137-
final ServletEnvironment servlets = environment.servlets();
138-
final FilterRegistration.Dynamic cors = servlets.addFilter("cross-origin-filter", CrossOriginFilter.class);
139-
cors.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, urlPattern);
140-
cors.setInitParameters(Collections.unmodifiableMap(builder));
135+
final PathMappingsHandler pathHandler = new PathMappingsHandler();
136+
pathHandler.addMapping(PathSpec.from(urlPattern), corsHandler);
137+
138+
final ServletContextHandler contextHandler = environment.getApplicationContext();
139+
contextHandler.insertHandler(corsHandler);
140+
141141
}
142142
}
Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,22 @@
11
package io.dropwizard.web.conf;
22

33
import com.google.common.collect.ImmutableList;
4-
import io.dropwizard.jetty.setup.ServletEnvironment;
4+
import io.dropwizard.jetty.MutableServletContextHandler;
55
import io.dropwizard.core.setup.Environment;
6-
import org.eclipse.jetty.ee10.servlets.CrossOriginFilter;
7-
import org.hamcrest.CoreMatchers;
6+
import java.util.Set;
7+
import org.eclipse.jetty.server.Handler;
8+
import org.eclipse.jetty.server.handler.CrossOriginHandler;
89
import org.junit.jupiter.api.BeforeEach;
910
import org.junit.jupiter.api.Test;
1011
import org.mockito.ArgumentCaptor;
1112
import org.mockito.Captor;
1213
import org.mockito.Mock;
1314
import org.mockito.MockitoAnnotations;
1415

15-
import java.util.Collections;
16-
import java.util.EnumSet;
17-
import java.util.Map;
18-
19-
import jakarta.servlet.Filter;
20-
import jakarta.servlet.FilterRegistration;
21-
2216
import static org.hamcrest.MatcherAssert.assertThat;
23-
import static org.hamcrest.Matchers.aMapWithSize;
24-
import static org.hamcrest.Matchers.hasEntry;
2517
import static org.hamcrest.Matchers.is;
26-
18+
import static org.hamcrest.Matchers.notNullValue;
2719
import static org.mockito.ArgumentMatchers.any;
28-
import static org.mockito.ArgumentMatchers.anyBoolean;
29-
import static org.mockito.ArgumentMatchers.anyMap;
30-
import static org.mockito.ArgumentMatchers.anyString;
31-
import static org.mockito.ArgumentMatchers.eq;
3220
import static org.mockito.Mockito.doNothing;
3321
import static org.mockito.Mockito.verify;
3422
import static org.mockito.Mockito.when;
@@ -37,30 +25,20 @@ public class CorsFilterFactoryTest {
3725
@Mock
3826
Environment env;
3927
@Mock
40-
ServletEnvironment servlets;
41-
@Mock
42-
FilterRegistration.Dynamic registration;
43-
@Captor
44-
ArgumentCaptor<String> filterNameCaptor;
28+
MutableServletContextHandler contextHandler;
4529
@Captor
46-
ArgumentCaptor<Class<? extends Filter>> filterClassCaptor;
47-
@Captor
48-
ArgumentCaptor<String> urlPatternCaptor;
49-
@Captor
50-
ArgumentCaptor<Map<String, String>> initParamsCaptor;
30+
ArgumentCaptor<Handler.Singleton> handlerCaptor;
5131

5232
@BeforeEach
5333
public void setUp() throws Exception {
5434
MockitoAnnotations.initMocks(this);
5535
}
5636

5737
@Test
58-
public void configureFilter() {
38+
public void configureHandler() {
5939
// given
60-
when(env.servlets()).thenReturn(servlets);
61-
when(servlets.addFilter(anyString(), any(Class.class))).thenReturn(registration);
62-
doNothing().when(registration).addMappingForUrlPatterns(any(EnumSet.class), anyBoolean(), anyString());
63-
when(registration.setInitParameters(anyMap())).thenReturn(Collections.emptySet());
40+
when(env.getApplicationContext()).thenReturn(contextHandler);
41+
doNothing().when(contextHandler).insertHandler(any(Handler.Singleton.class));
6442
String urlPattern = "/example/*";
6543
CorsFilterFactory factory = new CorsFilterFactory();
6644
factory.setAllowedOrigins(ImmutableList.of("example.com", "foo.com"));
@@ -69,13 +47,13 @@ public void configureFilter() {
6947
factory.build(env, urlPattern);
7048

7149
// then
72-
verify(servlets).addFilter(filterNameCaptor.capture(), filterClassCaptor.capture());
73-
verify(registration).addMappingForUrlPatterns(any(), eq(true), urlPatternCaptor.capture());
74-
verify(registration).setInitParameters(initParamsCaptor.capture());
75-
assertThat(filterNameCaptor.getValue(), is("cross-origin-filter"));
76-
assertThat(filterClassCaptor.getValue(), is(CoreMatchers.<Class<?>>equalTo(CrossOriginFilter.class)));
77-
assertThat(urlPatternCaptor.getValue(), is(urlPattern));
78-
assertThat(initParamsCaptor.getValue(), aMapWithSize(1));
79-
assertThat(initParamsCaptor.getValue(), hasEntry(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, "example.com,foo.com"));
50+
verify(contextHandler).insertHandler(handlerCaptor.capture());
51+
assertThat(handlerCaptor.getValue(), is(notNullValue()));
52+
assertThat(handlerCaptor.getValue() instanceof CrossOriginHandler, is(true));
53+
assertThat(((CrossOriginHandler) handlerCaptor.getValue())
54+
.getAllowedOriginPatterns()
55+
, is(Set.of("example.com", "foo.com")));
56+
57+
8058
}
8159
}

0 commit comments

Comments
 (0)