Skip to content

Commit 50adc61

Browse files
authored
Merge pull request #120 from networknt/issue119
fixes #119 Add cors handler to handle the in-flight
2 parents a0b11af + 83e02b8 commit 50adc61

File tree

10 files changed

+491
-3
lines changed

10 files changed

+491
-3
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@
282282
<artifactId>router-config</artifactId>
283283
<version>${version.light-4j}</version>
284284
</dependency>
285+
<dependency>
286+
<groupId>com.networknt</groupId>
287+
<artifactId>cors-config</artifactId>
288+
<version>${version.light-4j}</version>
289+
</dependency>
285290
<dependency>
286291
<groupId>com.networknt</groupId>
287292
<artifactId>caffeine-cache</artifactId>

src/main/java/com/networknt/aws/lambda/handler/Handler.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import java.lang.reflect.InvocationTargetException;
1616
import java.util.*;
1717

18-
import static io.undertow.util.PathTemplateMatch.ATTACHMENT_KEY;
19-
2018
public class Handler {
2119

2220

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package com.networknt.aws.lambda.handler.middleware.cors;
2+
3+
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
4+
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
5+
import com.networknt.aws.lambda.LightLambdaExchange;
6+
import com.networknt.aws.lambda.handler.MiddlewareHandler;
7+
import com.networknt.config.Config;
8+
import com.networknt.cors.CorsConfig;
9+
import com.networknt.status.Status;
10+
import com.networknt.utility.MapUtil;
11+
import com.networknt.utility.ModuleRegistry;
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
15+
import java.util.*;
16+
17+
import static com.networknt.cors.CorsHeaders.*;
18+
import static com.networknt.cors.CorsUtil.sanitizeDefaultPort;
19+
20+
/**
21+
* This middleware is responsible for adding CORS headers to the response and return it if the request
22+
* is a CORS preflight request. If the preflight request has the correct origin and method, it will return
23+
* 200 with the correct headers. Otherwise, 403 will be returned.
24+
*
25+
* If the request is a normal request with origin header and the origin is not matched, it will return 403.
26+
*
27+
* @author Steve Hu
28+
*
29+
*/
30+
public class RequestCorsMiddleware implements MiddlewareHandler {
31+
static CorsConfig CONFIG;
32+
private static final Logger LOG = LoggerFactory.getLogger(RequestCorsMiddleware.class);
33+
private static final String SUC10200 = "SUC10200";
34+
private static final String CORS_PREFLIGHT_REQUEST_FAILED = "ERR10092";
35+
36+
private List<String> allowedOrigins;
37+
private List<String> allowedMethods;
38+
private static final String ONE_HOUR_IN_SECONDS = "3600";
39+
40+
public RequestCorsMiddleware() {
41+
CONFIG = CorsConfig.load();
42+
allowedOrigins = CONFIG.getAllowedOrigins();
43+
allowedMethods = CONFIG.getAllowedMethods();
44+
LOG.info("RequestCorsMiddleware is constructed");
45+
}
46+
47+
public RequestCorsMiddleware(CorsConfig cfg) {
48+
CONFIG = cfg;
49+
allowedOrigins = CONFIG.getAllowedOrigins();
50+
allowedMethods = CONFIG.getAllowedMethods();
51+
LOG.info("RequestCorsMiddleware is constructed");
52+
}
53+
54+
@Override
55+
public Status execute(LightLambdaExchange exchange) {
56+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware.executeMiddleware starts.");
57+
if (!CONFIG.isEnabled()) {
58+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware is not enabled.");
59+
return disabledMiddlewareStatus();
60+
}
61+
APIGatewayProxyRequestEvent requestEvent = exchange.getRequest();
62+
if(requestEvent != null) {
63+
if(LOG.isTraceEnabled()) LOG.trace("Request event is not null.");
64+
Map<String, String> requestHeaders = requestEvent.getHeaders();
65+
if(isCorsRequest(requestHeaders)) {
66+
// set the allowed origins and methods based on the path prefix.
67+
if (CONFIG.getPathPrefixAllowed() != null) {
68+
String requestPath = requestEvent.getPath();
69+
for(Map.Entry<String, Object> entry: CONFIG.getPathPrefixAllowed().entrySet()) {
70+
if (requestPath.startsWith(entry.getKey())) {
71+
Map endpointCorsMap = (Map) entry.getValue();
72+
allowedOrigins = (List<String>) endpointCorsMap.get(CorsConfig.ALLOWED_ORIGINS);
73+
allowedMethods = (List<String>) endpointCorsMap.get(CorsConfig.ALLOWED_METHODS);
74+
break;
75+
}
76+
}
77+
}
78+
// if it is a preflight request, then handle it and return.
79+
if (isPreflightedRequest(requestEvent.getHttpMethod())) {
80+
// it is a preflight request. Handle it and return the response.
81+
return handlePreflightRequest(exchange, allowedOrigins, allowedMethods);
82+
} else {
83+
// normal request with origin header. check the origin and reject if it is not matched.
84+
String origin = matchOrigin(requestEvent, allowedOrigins);
85+
if(origin == null) {
86+
return new Status(CORS_PREFLIGHT_REQUEST_FAILED);
87+
}
88+
}
89+
}
90+
}
91+
// need to set the response header for the normal cors request that passed the origin check. It needs
92+
// to be set in the response chain instead of request chain.
93+
94+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware.executeMiddleware ends.");
95+
return successMiddlewareStatus();
96+
}
97+
98+
@Override
99+
public boolean isEnabled() {
100+
return CONFIG.isEnabled();
101+
}
102+
103+
@Override
104+
public void register() {
105+
ModuleRegistry.registerModule(
106+
CorsConfig.CONFIG_NAME,
107+
RequestCorsMiddleware.class.getName(),
108+
Config.getNoneDecryptedInstance().getJsonMapConfigNoCache(CorsConfig.CONFIG_NAME),
109+
null
110+
);
111+
}
112+
113+
@Override
114+
public void reload() {
115+
116+
}
117+
118+
@Override
119+
public boolean isAsynchronous() {
120+
return false;
121+
}
122+
123+
@Override
124+
public boolean isContinueOnFailure() {
125+
return false;
126+
}
127+
128+
@Override
129+
public boolean isAudited() {
130+
return false;
131+
}
132+
133+
@Override
134+
public void getCachedConfigurations() {
135+
136+
}
137+
138+
private Status handlePreflightRequest(LightLambdaExchange exchange, List<String> allowedOrigins, List<String> allowedMethods) {
139+
APIGatewayProxyResponseEvent responseEvent = new APIGatewayProxyResponseEvent();
140+
Map<String, String> requestHeaders = exchange.getRequest().getHeaders();
141+
Map<String, String> responseHeaders = new HashMap<>();
142+
if (MapUtil.getValueIgnoreCase(requestHeaders, ORIGIN).isPresent()) {
143+
if(matchOrigin(exchange.getRequest(), allowedOrigins) != null) {
144+
responseHeaders.put(ACCESS_CONTROL_ALLOW_ORIGIN, MapUtil.getValueIgnoreCase(requestHeaders, ORIGIN).get());
145+
responseHeaders.put("Vary", "Origin");
146+
} else {
147+
responseEvent.setHeaders(responseHeaders);
148+
responseEvent.setStatusCode(403);
149+
exchange.setInitialResponse(responseEvent);
150+
return new Status(CORS_PREFLIGHT_REQUEST_FAILED);
151+
}
152+
}
153+
responseHeaders.put(ACCESS_CONTROL_ALLOW_METHODS, convertToString(allowedMethods));
154+
Optional<String> acRequestHeaders = MapUtil.getValueIgnoreCase(requestHeaders, ACCESS_CONTROL_REQUEST_HEADERS);
155+
if (acRequestHeaders.isPresent()) {
156+
responseHeaders.put(ACCESS_CONTROL_ALLOW_HEADERS, acRequestHeaders.get());
157+
} else {
158+
responseHeaders.put(ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, WWW-Authenticate, Authorization");
159+
}
160+
responseHeaders.put(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
161+
responseHeaders.put(ACCESS_CONTROL_MAX_AGE, ONE_HOUR_IN_SECONDS);
162+
responseEvent.setHeaders(responseHeaders);
163+
responseEvent.setStatusCode(200);
164+
exchange.setInitialResponse(responseEvent);
165+
return new Status(SUC10200);
166+
}
167+
168+
/**
169+
* Match the Origin header with the allowed origins.
170+
* If it doesn't match then a 403 response code is set on the response and it returns null.
171+
* @param requestEvent the current request event.
172+
* @param allowedOrigins list of sanitized allowed origins.
173+
* @return the first matching origin, null otherwise.
174+
*/
175+
static String matchOrigin(APIGatewayProxyRequestEvent requestEvent, Collection<String> allowedOrigins) {
176+
Map<String, String> requestHeaders = requestEvent.getHeaders();
177+
Optional<String> optionalOrigin = MapUtil.getValueIgnoreCase(requestHeaders, ORIGIN);
178+
String origin = optionalOrigin.orElse(null);
179+
if(LOG.isTraceEnabled()) LOG.trace("origin from the request header = {} allowedOrigins = {}", origin, allowedOrigins);
180+
if (origin != null && allowedOrigins != null && !allowedOrigins.isEmpty()) {
181+
for (String allowedOrigin : allowedOrigins) {
182+
if (allowedOrigin.equalsIgnoreCase(sanitizeDefaultPort(origin))) {
183+
return allowedOrigin;
184+
}
185+
}
186+
}
187+
LOG.debug("Request rejected due to HOST/ORIGIN mis-match.");
188+
return null;
189+
}
190+
191+
static boolean isCorsRequest(Map<String, String> requestHeaders) {
192+
// all cors request will have origin header regardless it is a preflight request
193+
// or normal request with method other than OPTIONS.
194+
return MapUtil.getValueIgnoreCase(requestHeaders, ORIGIN).isPresent();
195+
}
196+
197+
static boolean isPreflightedRequest(String requestMethod) {
198+
// only the preflight request will have OPTIONS method, and it should be checked
199+
// after it is confirmed a cors request with origin header.
200+
return "OPTIONS".equalsIgnoreCase(requestMethod);
201+
}
202+
203+
static String convertToString(List<String> list) {
204+
return String.join(",", list);
205+
}
206+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package com.networknt.aws.lambda.handler.middleware.cors;
2+
3+
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
4+
import com.networknt.aws.lambda.LightLambdaExchange;
5+
import com.networknt.aws.lambda.handler.MiddlewareHandler;
6+
import com.networknt.config.Config;
7+
import com.networknt.cors.CorsConfig;
8+
import com.networknt.status.Status;
9+
import com.networknt.utility.MapUtil;
10+
import com.networknt.utility.ModuleRegistry;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
14+
/**
15+
* This middleware is responsible for adding CORS headers to the response if the request is a CORS request.
16+
* It means the request has a header Origin. The middleware will add the following headers to the response:
17+
* Access-Control-Allow-Origin: The value of the Origin header
18+
* Access-Control-Allow-Methods: The value of the Access-Control-Request-Method header
19+
*
20+
* @author Steve Hu
21+
*/
22+
public class ResponseCorsMiddleware implements MiddlewareHandler {
23+
24+
static CorsConfig CONFIG;
25+
private static final Logger LOG = LoggerFactory.getLogger(ResponseCorsMiddleware.class);
26+
27+
public ResponseCorsMiddleware() {
28+
CONFIG = CorsConfig.load();
29+
LOG.info("ResponseCorsMiddleware is constructed");
30+
}
31+
32+
public ResponseCorsMiddleware(CorsConfig cfg) {
33+
CONFIG = cfg;
34+
LOG.info("ResponseCorsMiddleware is constructed");
35+
}
36+
37+
@Override
38+
public Status execute(LightLambdaExchange exchange) {
39+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware.executeMiddleware starts.");
40+
if (!CONFIG.isEnabled()) {
41+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware is not enabled.");
42+
return disabledMiddlewareStatus();
43+
}
44+
APIGatewayProxyResponseEvent responseEvent = exchange.getResponse();
45+
if(responseEvent != null) {
46+
if (LOG.isTraceEnabled()) LOG.trace("Response event is not null.");
47+
var responseHeaders = responseEvent.getHeaders();
48+
if (responseHeaders != null) {
49+
if (LOG.isTraceEnabled()) LOG.trace("Response headers is not null.");
50+
if (MapUtil.getValueIgnoreCase(exchange.getReadOnlyRequest().getHeaders(), "Origin").isPresent()) {
51+
// this is a CORS request, and it is passed the CORS check in the RequestCorsMiddleware.
52+
responseHeaders.put("Access-Control-Allow-Origin", exchange.getReadOnlyRequest().getHeaders().get("Origin"));
53+
}
54+
}
55+
}
56+
if(LOG.isTraceEnabled()) LOG.trace("RequestCorsMiddleware.executeMiddleware ends.");
57+
return successMiddlewareStatus();
58+
}
59+
60+
@Override
61+
public boolean isEnabled() {
62+
return CONFIG.isEnabled();
63+
}
64+
65+
@Override
66+
public void register() {
67+
ModuleRegistry.registerModule(
68+
CorsConfig.CONFIG_NAME,
69+
ResponseCorsMiddleware.class.getName(),
70+
Config.getNoneDecryptedInstance().getJsonMapConfigNoCache(CorsConfig.CONFIG_NAME),
71+
null
72+
);
73+
}
74+
75+
@Override
76+
public void reload() {
77+
}
78+
79+
@Override
80+
public boolean isAsynchronous() {
81+
return false;
82+
}
83+
84+
@Override
85+
public boolean isResponseMiddleware() {
86+
return true;
87+
}
88+
89+
@Override
90+
public boolean isContinueOnFailure() {
91+
return false;
92+
}
93+
94+
@Override
95+
public boolean isAudited() {
96+
return false;
97+
}
98+
99+
@Override
100+
public void getCachedConfigurations() {
101+
}
102+
103+
}

src/main/resources/config/handler.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ chains:
5050
- limit
5151
- traceability
5252
- correlation
53+
- requestCors
5354
- requestHeader
5455
- requestTransformer
5556
- audit
@@ -61,6 +62,7 @@ chains:
6162
- validator
6263
response:
6364
- responseHeader
65+
- responseCors
6466
- responseTransformer
6567

6668
admin:

src/test/java/com/networknt/aws/lambda/handler/HandlerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ public class HandlerTest {
88
@Test
99
public void testInitHandler() {
1010
Handler.init();
11-
Assertions.assertEquals(25, Handler.getHandlers().size());
11+
Assertions.assertEquals(27, Handler.getHandlers().size());
1212
}
1313
}

0 commit comments

Comments
 (0)