11//! Response types
22
3- use crate :: body:: Body ;
3+ use crate :: { body:: Body , request :: RequestOrigin } ;
44use http:: {
5- header:: { HeaderMap , HeaderValue , CONTENT_TYPE } ,
5+ header:: { HeaderMap , HeaderValue , CONTENT_TYPE , SET_COOKIE } ,
66 Response ,
77} ;
88use serde:: {
9- ser:: { Error as SerError , SerializeMap } ,
9+ ser:: { Error as SerError , SerializeMap , SerializeSeq } ,
1010 Serialize , Serializer ,
1111} ;
1212
13- /// Representation of API Gateway response
13+ /// Representation of Lambda response
14+ #[ doc( hidden) ]
15+ #[ derive( Serialize , Debug ) ]
16+ #[ serde( untagged) ]
17+ pub enum LambdaResponse {
18+ ApiGatewayV2 ( ApiGatewayV2Response ) ,
19+ Alb ( AlbResponse ) ,
20+ ApiGateway ( ApiGatewayResponse ) ,
21+ }
22+
23+ /// Representation of API Gateway v2 lambda response
1424#[ doc( hidden) ]
1525#[ derive( Serialize , Debug ) ]
1626#[ serde( rename_all = "camelCase" ) ]
17- pub struct LambdaResponse {
18- pub status_code : u16 ,
19- // ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error
20- // when one is provided. only populate this for ALB responses
27+ pub struct ApiGatewayV2Response {
28+ status_code : u16 ,
29+ #[ serde( serialize_with = "serialize_headers" ) ]
30+ headers : HeaderMap < HeaderValue > ,
31+ #[ serde( serialize_with = "serialize_headers_slice" ) ]
32+ cookies : Vec < HeaderValue > ,
2133 #[ serde( skip_serializing_if = "Option::is_none" ) ]
22- pub status_description : Option < String > ,
34+ body : Option < Body > ,
35+ is_base64_encoded : bool ,
36+ }
37+
38+ /// Representation of ALB lambda response
39+ #[ doc( hidden) ]
40+ #[ derive( Serialize , Debug ) ]
41+ #[ serde( rename_all = "camelCase" ) ]
42+ pub struct AlbResponse {
43+ status_code : u16 ,
44+ status_description : String ,
2345 #[ serde( serialize_with = "serialize_headers" ) ]
24- pub headers : HeaderMap < HeaderValue > ,
25- #[ serde( serialize_with = "serialize_multi_value_headers" ) ]
26- pub multi_value_headers : HeaderMap < HeaderValue > ,
46+ headers : HeaderMap < HeaderValue > ,
2747 #[ serde( skip_serializing_if = "Option::is_none" ) ]
28- pub body : Option < Body > ,
29- // This field is optional for API Gateway but required for ALB
30- pub is_base64_encoded : bool ,
48+ body : Option < Body > ,
49+ is_base64_encoded : bool ,
3150}
3251
33- #[ cfg( test) ]
34- impl Default for LambdaResponse {
35- fn default ( ) -> Self {
36- Self {
37- status_code : 200 ,
38- status_description : Default :: default ( ) ,
39- headers : Default :: default ( ) ,
40- multi_value_headers : Default :: default ( ) ,
41- body : Default :: default ( ) ,
42- is_base64_encoded : Default :: default ( ) ,
43- }
44- }
52+ /// Representation of API Gateway lambda response
53+ #[ doc( hidden) ]
54+ #[ derive( Serialize , Debug ) ]
55+ #[ serde( rename_all = "camelCase" ) ]
56+ pub struct ApiGatewayResponse {
57+ status_code : u16 ,
58+ #[ serde( serialize_with = "serialize_headers" ) ]
59+ headers : HeaderMap < HeaderValue > ,
60+ #[ serde( serialize_with = "serialize_multi_value_headers" ) ]
61+ multi_value_headers : HeaderMap < HeaderValue > ,
62+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
63+ body : Option < Body > ,
64+ is_base64_encoded : bool ,
4565}
4666
4767/// Serialize a http::HeaderMap into a serde str => str map
7393 map. end ( )
7494}
7595
96+ /// Serialize a &[HeaderValue] into a Vec<str>
97+ fn serialize_headers_slice < S > ( headers : & [ HeaderValue ] , serializer : S ) -> Result < S :: Ok , S :: Error >
98+ where
99+ S : Serializer ,
100+ {
101+ let mut seq = serializer. serialize_seq ( Some ( headers. len ( ) ) ) ?;
102+ for header in headers {
103+ seq. serialize_element ( header. to_str ( ) . map_err ( S :: Error :: custom) ?) ?;
104+ }
105+ seq. end ( )
106+ }
107+
76108/// tranformation from http type to internal type
77109impl LambdaResponse {
78- pub ( crate ) fn from_response < T > ( is_alb : bool , value : Response < T > ) -> Self
110+ pub ( crate ) fn from_response < T > ( request_origin : & RequestOrigin , value : Response < T > ) -> Self
79111 where
80112 T : Into < Body > ,
81113 {
@@ -85,21 +117,43 @@ impl LambdaResponse {
85117 b @ Body :: Text ( _) => ( false , Some ( b) ) ,
86118 b @ Body :: Binary ( _) => ( true , Some ( b) ) ,
87119 } ;
88- Self {
89- status_code : parts. status . as_u16 ( ) ,
90- status_description : if is_alb {
91- Some ( format ! (
120+
121+ let mut headers = parts. headers ;
122+ let status_code = parts. status . as_u16 ( ) ;
123+
124+ match request_origin {
125+ RequestOrigin :: ApiGatewayV2 => {
126+ // ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
127+ // so remove them from the headers.
128+ let cookies: Vec < HeaderValue > = headers. get_all ( SET_COOKIE ) . iter ( ) . cloned ( ) . collect ( ) ;
129+ headers. remove ( SET_COOKIE ) ;
130+
131+ LambdaResponse :: ApiGatewayV2 ( ApiGatewayV2Response {
132+ body,
133+ status_code,
134+ is_base64_encoded,
135+ cookies,
136+ headers,
137+ } )
138+ }
139+ RequestOrigin :: ApiGateway => LambdaResponse :: ApiGateway ( ApiGatewayResponse {
140+ body,
141+ status_code,
142+ is_base64_encoded,
143+ headers : headers. clone ( ) ,
144+ multi_value_headers : headers,
145+ } ) ,
146+ RequestOrigin :: Alb => LambdaResponse :: Alb ( AlbResponse {
147+ body,
148+ status_code,
149+ is_base64_encoded,
150+ headers,
151+ status_description : format ! (
92152 "{} {}" ,
93- parts . status . as_u16 ( ) ,
153+ status_code ,
94154 parts. status. canonical_reason( ) . unwrap_or_default( )
95- ) )
96- } else {
97- None
98- } ,
99- body,
100- headers : parts. headers . clone ( ) ,
101- multi_value_headers : parts. headers ,
102- is_base64_encoded,
155+ ) ,
156+ } ) ,
103157 }
104158 }
105159}
@@ -159,10 +213,42 @@ impl IntoResponse for serde_json::Value {
159213
160214#[ cfg( test) ]
161215mod tests {
162- use super :: { Body , IntoResponse , LambdaResponse } ;
216+ use super :: {
217+ AlbResponse , ApiGatewayResponse , ApiGatewayV2Response , Body , IntoResponse , LambdaResponse , RequestOrigin ,
218+ } ;
163219 use http:: { header:: CONTENT_TYPE , Response } ;
164220 use serde_json:: { self , json} ;
165221
222+ fn api_gateway_response ( ) -> ApiGatewayResponse {
223+ ApiGatewayResponse {
224+ status_code : 200 ,
225+ headers : Default :: default ( ) ,
226+ multi_value_headers : Default :: default ( ) ,
227+ body : Default :: default ( ) ,
228+ is_base64_encoded : Default :: default ( ) ,
229+ }
230+ }
231+
232+ fn alb_response ( ) -> AlbResponse {
233+ AlbResponse {
234+ status_code : 200 ,
235+ status_description : "200 OK" . to_string ( ) ,
236+ headers : Default :: default ( ) ,
237+ body : Default :: default ( ) ,
238+ is_base64_encoded : Default :: default ( ) ,
239+ }
240+ }
241+
242+ fn api_gateway_v2_response ( ) -> ApiGatewayV2Response {
243+ ApiGatewayV2Response {
244+ status_code : 200 ,
245+ headers : Default :: default ( ) ,
246+ body : Default :: default ( ) ,
247+ cookies : Default :: default ( ) ,
248+ is_base64_encoded : Default :: default ( ) ,
249+ }
250+ }
251+
166252 #[ test]
167253 fn json_into_response ( ) {
168254 let response = json ! ( { "hello" : "lambda" } ) . into_response ( ) ;
@@ -189,32 +275,39 @@ mod tests {
189275 }
190276
191277 #[ test]
192- fn default_response ( ) {
193- assert_eq ! ( LambdaResponse :: default ( ) . status_code, 200 )
278+ fn serialize_body_for_api_gateway ( ) {
279+ let mut resp = api_gateway_response ( ) ;
280+ resp. body = Some ( "foo" . into ( ) ) ;
281+ assert_eq ! (
282+ serde_json:: to_string( & resp) . expect( "failed to serialize response" ) ,
283+ r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
284+ ) ;
194285 }
195286
196287 #[ test]
197- fn serialize_default ( ) {
288+ fn serialize_body_for_alb ( ) {
289+ let mut resp = alb_response ( ) ;
290+ resp. body = Some ( "foo" . into ( ) ) ;
198291 assert_eq ! (
199- serde_json:: to_string( & LambdaResponse :: default ( ) ) . expect( "failed to serialize response" ) ,
200- r#"{"statusCode":200,"headers":{},"multiValueHeaders":{} ,"isBase64Encoded":false}"#
292+ serde_json:: to_string( & resp ) . expect( "failed to serialize response" ) ,
293+ r#"{"statusCode":200,"statusDescription":"200 OK"," headers":{},"body":"foo" ,"isBase64Encoded":false}"#
201294 ) ;
202295 }
203296
204297 #[ test]
205- fn serialize_body ( ) {
206- let mut resp = LambdaResponse :: default ( ) ;
298+ fn serialize_body_for_api_gateway_v2 ( ) {
299+ let mut resp = api_gateway_v2_response ( ) ;
207300 resp. body = Some ( "foo" . into ( ) ) ;
208301 assert_eq ! (
209302 serde_json:: to_string( & resp) . expect( "failed to serialize response" ) ,
210- r#"{"statusCode":200,"headers":{},"multiValueHeaders":{} ,"body":"foo","isBase64Encoded":false}"#
303+ r#"{"statusCode":200,"headers":{},"cookies":[] ,"body":"foo","isBase64Encoded":false}"#
211304 ) ;
212305 }
213306
214307 #[ test]
215308 fn serialize_multi_value_headers ( ) {
216309 let res = LambdaResponse :: from_response (
217- false ,
310+ & RequestOrigin :: ApiGateway ,
218311 Response :: builder ( )
219312 . header ( "multi" , "a" )
220313 . header ( "multi" , "b" )
@@ -227,4 +320,21 @@ mod tests {
227320 r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
228321 )
229322 }
323+
324+ #[ test]
325+ fn serialize_cookies ( ) {
326+ let res = LambdaResponse :: from_response (
327+ & RequestOrigin :: ApiGatewayV2 ,
328+ Response :: builder ( )
329+ . header ( "set-cookie" , "cookie1=a" )
330+ . header ( "set-cookie" , "cookie2=b" )
331+ . body ( Body :: from ( ( ) ) )
332+ . expect ( "failed to create response" ) ,
333+ ) ;
334+ let json = serde_json:: to_string ( & res) . expect ( "failed to serialize to json" ) ;
335+ assert_eq ! (
336+ json,
337+ r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"#
338+ )
339+ }
230340}
0 commit comments