Skip to content

Commit 8fef6b5

Browse files
committed
#1084: decode body if base64 is enable
1 parent e6abc1f commit 8fef6b5

File tree

2 files changed

+89
-28
lines changed

2 files changed

+89
-28
lines changed

aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
package com.amazonaws.serverless.proxy.spring;
22

33
import java.io.InputStream;
4+
import java.nio.charset.Charset;
45
import java.nio.charset.StandardCharsets;
5-
import java.util.Iterator;
6+
import java.nio.charset.UnsupportedCharsetException;
7+
import java.util.Base64;
68
import java.util.Map;
79
import java.util.Map.Entry;
8-
import java.util.Set;
910
import java.util.concurrent.CountDownLatch;
1011
import java.util.concurrent.TimeUnit;
1112

13+
import org.apache.commons.io.Charsets;
1214
import org.apache.commons.logging.Log;
1315
import org.apache.commons.logging.LogFactory;
1416
import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest;
1517
import org.springframework.cloud.function.serverless.web.ServerlessMVC;
18+
import org.springframework.http.HttpHeaders;
1619
import org.springframework.util.CollectionUtils;
1720
import org.springframework.util.FileCopyUtils;
1821
import org.springframework.util.MultiValueMapAdapter;
1922
import org.springframework.util.StringUtils;
2023

21-
import com.amazonaws.serverless.proxy.AsyncInitializationWrapper;
2224
import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter;
2325
import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter;
2426
import com.amazonaws.serverless.proxy.RequestReader;
@@ -120,10 +122,14 @@ private static HttpServletRequest generateRequest1(String request, Context lambd
120122
MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders());
121123
httpRequest.setHeaders(headers);
122124
}
123-
if (StringUtils.hasText(v1Request.getBody())) {
124-
httpRequest.setContentType("application/json");
125-
httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8));
126-
}
125+
if (StringUtils.hasText(v1Request.getBody())) {
126+
if (v1Request.isBase64Encoded()) {
127+
httpRequest.setContent(Base64.getMimeDecoder().decode(v1Request.getBody()));
128+
} else {
129+
Charset charseEncoding = parseCharacterEncoding(v1Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
130+
httpRequest.setContent(v1Request.getBody().getBytes(charseEncoding));
131+
}
132+
}
127133
if (v1Request.getRequestContext() != null) {
128134
httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext());
129135
httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb());
@@ -149,11 +155,15 @@ private static HttpServletRequest generateRequest2(String request, Context lambd
149155
populateQueryStringparameters(v2Request.getQueryStringParameters(), httpRequest);
150156

151157
v2Request.getHeaders().forEach(httpRequest::setHeader);
152-
153-
if (StringUtils.hasText(v2Request.getBody())) {
154-
httpRequest.setContentType("application/json");
155-
httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8));
156-
}
158+
159+
if (StringUtils.hasText(v2Request.getBody())) {
160+
if (v2Request.isBase64Encoded()) {
161+
httpRequest.setContent(Base64.getMimeDecoder().decode(v2Request.getBody()));
162+
} else {
163+
Charset charseEncoding = parseCharacterEncoding(v2Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
164+
httpRequest.setContent(v2Request.getBody().getBytes(charseEncoding));
165+
}
166+
}
157167
httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext());
158168
httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables());
159169
httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request);
@@ -180,4 +190,36 @@ private static <T> T readValue(String json, Class<T> clazz, ObjectMapper mapper)
180190
}
181191
}
182192

193+
static final String HEADER_KEY_VALUE_SEPARATOR = "=";
194+
static final String HEADER_VALUE_SEPARATOR = ";";
195+
static final String ENCODING_VALUE_KEY = "charset";
196+
static protected Charset parseCharacterEncoding(String contentTypeHeader) {
197+
// we only look at content-type because content-encoding should only be used for
198+
// "binary" requests such as gzip/deflate.
199+
Charset defaultCharset = StandardCharsets.UTF_8;
200+
if (contentTypeHeader == null) {
201+
return defaultCharset;
202+
}
203+
204+
String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR);
205+
if (contentTypeValues.length <= 1) {
206+
return defaultCharset;
207+
}
208+
209+
for (String contentTypeValue : contentTypeValues) {
210+
if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) {
211+
String[] encodingValues = contentTypeValue.split(HEADER_KEY_VALUE_SEPARATOR);
212+
if (encodingValues.length <= 1) {
213+
return defaultCharset;
214+
}
215+
try {
216+
return Charsets.toCharset(encodingValues[1]);
217+
} catch (UnsupportedCharsetException ex) {
218+
return defaultCharset;
219+
}
220+
}
221+
}
222+
return defaultCharset;
223+
}
224+
183225
}

aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringDelegatingLambdaContainerHandlerTests.java

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@
66
import java.io.ByteArrayOutputStream;
77
import java.io.InputStream;
88
import java.nio.charset.StandardCharsets;
9-
import java.util.Arrays;
10-
import java.util.Collection;
11-
import java.util.HashMap;
12-
import java.util.Map;
9+
import java.util.*;
1310

1411
import com.amazonaws.serverless.exceptions.ContainerInitializationException;
1512
import org.junit.jupiter.params.ParameterizedTest;
1613
import org.junit.jupiter.params.provider.MethodSource;
17-
import org.springframework.cloud.function.serverless.web.ServerlessServletContext;
1814
import org.springframework.util.CollectionUtils;
1915

2016
import com.amazonaws.serverless.proxy.spring.servletapp.MessageData;
@@ -214,7 +210,7 @@ public static Collection<String> data() {
214210
public void validateComplesrequest(String jsonEvent) throws Exception {
215211
initServletAppTest();
216212
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST",
217-
"/foo/male/list/24", "{\"name\":\"bob\"}", null));
213+
"/foo/male/list/24", "{\"name\":\"bob\"}", false,null));
218214
ByteArrayOutputStream output = new ByteArrayOutputStream();
219215
handler.handleRequest(targetStream, output, null);
220216
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
@@ -229,7 +225,7 @@ public void validateComplesrequest(String jsonEvent) throws Exception {
229225
@ParameterizedTest
230226
public void testAsyncPost(String jsonEvent) throws Exception {
231227
initServletAppTest();
232-
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}", null));
228+
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}",false, null));
233229
ByteArrayOutputStream output = new ByteArrayOutputStream();
234230
handler.handleRequest(targetStream, output, null);
235231
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
@@ -242,7 +238,7 @@ public void testAsyncPost(String jsonEvent) throws Exception {
242238
public void testValidate400(String jsonEvent) throws Exception {
243239
initServletAppTest();
244240
UserData ud = new UserData();
245-
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
241+
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
246242
ByteArrayOutputStream output = new ByteArrayOutputStream();
247243
handler.handleRequest(targetStream, output, null);
248244
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
@@ -258,27 +254,48 @@ public void testValidate200(String jsonEvent) throws Exception {
258254
ud.setFirstName("bob");
259255
ud.setLastName("smith");
260256
ud.setEmail("[email protected]");
261-
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
257+
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
262258
ByteArrayOutputStream output = new ByteArrayOutputStream();
263259
handler.handleRequest(targetStream, output, null);
264260
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
265261
assertEquals(200, result.get("statusCode"));
266262
assertEquals("VALID", result.get("body"));
267263
}
268264

265+
@MethodSource("data")
266+
@ParameterizedTest
267+
public void testValidate200Base64(String jsonEvent) throws Exception {
268+
initServletAppTest();
269+
UserData ud = new UserData();
270+
ud.setFirstName("bob");
271+
ud.setLastName("smith");
272+
ud.setEmail("[email protected]");
273+
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate",
274+
Base64.getMimeEncoder().encodeToString(mapper.writeValueAsString(ud).getBytes()),true, null));
275+
276+
ByteArrayOutputStream output = new ByteArrayOutputStream();
277+
handler.handleRequest(targetStream, output, null);
278+
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
279+
assertEquals(200, result.get("statusCode"));
280+
assertEquals("VALID", result.get("body"));
281+
}
282+
283+
269284
@MethodSource("data")
270285
@ParameterizedTest
271286
public void messageObject_parsesObject_returnsCorrectMessage(String jsonEvent) throws Exception {
272287
initServletAppTest();
273288
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
274-
mapper.writeValueAsString(new MessageData("test message")), null));
289+
mapper.writeValueAsString(new MessageData("test message")),false, null));
275290
ByteArrayOutputStream output = new ByteArrayOutputStream();
276291
handler.handleRequest(targetStream, output, null);
277292
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
278293
assertEquals(200, result.get("statusCode"));
279294
assertEquals("test message", result.get("body"));
280295
}
281296

297+
298+
282299
@SuppressWarnings({"unchecked" })
283300
@MethodSource("data")
284301
@ParameterizedTest
@@ -289,40 +306,42 @@ void messageObject_propertiesInContentType_returnsCorrectMessage(String jsonEven
289306
headers.put(HttpHeaders.CONTENT_TYPE, "application/json;v=1");
290307
headers.put(HttpHeaders.ACCEPT, "application/json;v=1");
291308
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
292-
mapper.writeValueAsString(new MessageData("test message")), headers));
309+
mapper.writeValueAsString(new MessageData("test message")),false, headers));
293310

294311
ByteArrayOutputStream output = new ByteArrayOutputStream();
295312
handler.handleRequest(targetStream, output, null);
296313
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
297314
assertEquals("test message", result.get("body"));
298315
}
299316

300-
private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body, Map headers) throws Exception {
317+
private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
301318
Map requestMap = mapper.readValue(jsonEvent, Map.class);
302319
if (requestMap.get("version").equals("2.0")) {
303-
return generateHttpRequest2(requestMap, method, path, body, headers);
320+
return generateHttpRequest2(requestMap, method, path, body, isBase64Encoded,headers);
304321
}
305-
return generateHttpRequest(requestMap, method, path, body, headers);
322+
return generateHttpRequest(requestMap, method, path, body,isBase64Encoded, headers);
306323
}
307324

308325
@SuppressWarnings({ "unchecked"})
309-
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body, Map headers) throws Exception {
326+
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
310327
requestMap.put("path", path);
311328
requestMap.put("httpMethod", method);
312329
requestMap.put("body", body);
330+
requestMap.put("isBase64Encoded", isBase64Encoded);
313331
if (!CollectionUtils.isEmpty(headers)) {
314332
requestMap.put("headers", headers);
315333
}
316334
return mapper.writeValueAsBytes(requestMap);
317335
}
318336

319337
@SuppressWarnings({ "unchecked"})
320-
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body, Map headers) throws Exception {
338+
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
321339
Map map = mapper.readValue(API_GATEWAY_EVENT_V2, Map.class);
322340
Map http = (Map) ((Map) map.get("requestContext")).get("http");
323341
http.put("path", path);
324342
http.put("method", method);
325343
map.put("body", body);
344+
map.put("isBase64Encoded", isBase64Encoded);
326345
if (!CollectionUtils.isEmpty(headers)) {
327346
map.put("headers", headers);
328347
}

0 commit comments

Comments
 (0)