@@ -14,21 +14,63 @@ use std::env;
1414/// This is the primary invocation point for the lambda and should do the heavy lifting
1515async fn function_handler ( event : LambdaEvent < SqsEvent > ) -> Result < ( ) , Error > {
1616 let table_uri = std:: env:: var ( "DELTA_TABLE_URI" ) . expect ( "Failed to get `DELTA_TABLE_URI`" ) ;
17- debug ! ( "payload received: {:?}" , event. payload. records) ;
1817
19- let records = match std:: env:: var ( "UNWRAP_SNS_ENVELOPE" ) {
20- Ok ( _) => unwrap_sns_payload ( & event. payload . records ) ,
21- Err ( _) => event. payload . records ,
22- } ;
18+ // Hold onto the instant that the function started in order to attempt to exit on time.
19+ let fn_start = std:: time:: Instant :: now ( ) ;
20+ trace ! ( "payload received: {:?}" , event. payload. records) ;
2321
24- let values = extract_json_from_records ( & records) ;
25- debug ! ( "JSON pulled out: {values:?}" ) ;
22+ let config = aws_config:: from_env ( ) . load ( ) . await ;
23+ let sqs_client = aws_sdk_sqs:: Client :: new ( & config) ;
24+ // How many more messages should sqs-ingest try to consume?
25+ let mut more_count = 0 ;
26+ // Millis to allow for consuming more messages
27+ let more_deadline_ms: u128 = ( event. context . deadline / 2 ) . into ( ) ;
28+ // records should contain the raw deserialized JSON payload that was sent through SQS. It
29+ // should be "fit" for writing to Delta
30+ let mut records: Vec < String > = vec ! [ ] ;
2631
27- if !values. is_empty ( ) {
32+ if let Ok ( how_many_more) = std:: env:: var ( "BUFFER_MORE_MESSAGES" ) {
33+ more_count = how_many_more
34+ . parse ( )
35+ . expect ( "The value of BUFFER_MORE_MESSAGES cannot be coerced into an int :thinking:" ) ;
36+ debug ! ( "sqs-ingest configured to consume an additional {more_count} messages from SQS" ) ;
37+ debug ! ( "sqs-ingest will attempt to retrieve {more_count} messages in no more than {more_deadline_ms}ms to avoid timing out the function" ) ;
38+ }
39+
40+ if more_count > 0 {
41+ if let Ok ( queue_url) = std:: env:: var ( "BUFFER_MORE_QUEUE_URL" ) {
42+ let mut completed = false ;
43+ let mut fetched = 0 ;
44+
45+ while !completed {
46+ let receive = sqs_client
47+ . receive_message ( )
48+ . max_number_of_messages ( 10 )
49+ . queue_url ( queue_url. clone ( ) )
50+ . send ( )
51+ . await ?;
52+ records. append ( & mut extract_json_from_sqs_direct (
53+ receive. messages . unwrap_or_default ( ) ,
54+ ) ) ;
55+
56+ fetched += 1 ;
57+ completed =
58+ ( fetched >= more_count) || ( fn_start. elapsed ( ) . as_millis ( ) >= more_deadline_ms) ;
59+ }
60+ } else {
61+ error ! ( "The function cannot buffer more messages without a BUFFER_MORE_QUEUE_URL! Only writing messages that triggered the Lambda" ) ;
62+ }
63+ }
64+
65+ // Add the messages that actually triggered this function invocation
66+ records. append ( & mut extract_json_from_records ( & event. payload . records ) ) ;
67+
68+ if !records. is_empty ( ) {
2869 let table = oxbow:: lock:: open_table ( & table_uri) . await ?;
29- match append_values ( table, values. as_slice ( ) ) . await {
70+
71+ match append_values ( table, records. as_slice ( ) ) . await {
3072 Ok ( table) => {
31- debug ! ( "Appended values to: {table:?}" ) ;
73+ debug ! ( "Appended {} values to: {table:?}" , records . len ( ) ) ;
3274 }
3375 Err ( e) => {
3476 error ! ( "Failed to append the values to configured Delta table: {e:?}" ) ;
@@ -39,6 +81,11 @@ async fn function_handler(event: LambdaEvent<SqsEvent>) -> Result<(), Error> {
3981 error ! ( "An empty payload was extracted which doesn't seem right!" ) ;
4082 }
4183
84+ debug ! (
85+ "sqs-ingest completed its work in {}ms" ,
86+ fn_start. elapsed( ) . as_millis( )
87+ ) ;
88+
4289 Ok ( ( ) )
4390}
4491
@@ -60,37 +107,53 @@ async fn main() -> Result<(), Error> {
60107 run ( service_fn ( function_handler) ) . await
61108}
62109
110+ /// Extract and deserialize the JSON from messages which were directly consumed from SQS rather
111+ /// than those received via Lambda triggering.
112+ ///
113+ /// This corresponds to the messages consumed from BUFFER_MORE_QUEUE_URL
114+ fn extract_json_from_sqs_direct ( messages : Vec < aws_sdk_sqs:: types:: Message > ) -> Vec < String > {
115+ if inside_sns ( ) {
116+ messages
117+ . iter ( )
118+ . filter ( |m| m. body ( ) . is_some ( ) )
119+ . map ( |m| m. body ( ) . as_ref ( ) . unwrap ( ) . to_string ( ) )
120+ . map ( |b| {
121+ let value: SNSWrapper =
122+ serde_json:: from_str ( & b) . expect ( "Failed to deserialize SNS payload as JSON" ) ;
123+ value. to_vec ( )
124+ } )
125+ . flatten ( )
126+ . collect :: < Vec < String > > ( )
127+ } else {
128+ messages
129+ . iter ( )
130+ . filter ( |m| m. body ( ) . is_some ( ) )
131+ . map ( |m| m. body ( ) . as_ref ( ) . unwrap ( ) . to_string ( ) )
132+ . collect :: < Vec < String > > ( )
133+ }
134+ }
135+
63136/// Convert the `body` payloads from [SqsMessage] entities into JSONL
64137/// which can be passed into the [oxbow::write::append_values] function
65138fn extract_json_from_records ( records : & [ SqsMessage ] ) -> Vec < String > {
66- records
67- . iter ( )
68- . filter ( |m| m. body . is_some ( ) )
69- . map ( |m| m. body . as_ref ( ) . unwrap ( ) . clone ( ) )
70- . collect :: < Vec < String > > ( )
71- }
72-
73- /// SNS cannot help but JSON encode all its payloads so sometimes we must unwrap it.
74- fn unwrap_sns_payload ( records : & [ SqsMessage ] ) -> Vec < SqsMessage > {
75- let mut unpacked = vec ! [ ] ;
76- for record in records {
77- if let Some ( body) = record. body . as_ref ( ) {
78- trace ! ( "Attempting to unwrap the contents of nested JSON: {body}" ) ;
79- let nested: SNSWrapper = serde_json:: from_str ( body) . expect (
80- "Failed to unpack SNS
81- messages, this could be a misconfiguration and there is no SNS envelope or raw_delivery has not
82- been set" ,
83- ) ;
84- for body in nested. records {
85- let message: SqsMessage = SqsMessage {
86- body : Some ( serde_json:: to_string ( & body) . expect ( "Failed to reserialize JSON" ) ) ,
87- ..Default :: default ( )
88- } ;
89- unpacked. push ( message) ;
90- }
91- }
139+ if inside_sns ( ) {
140+ records
141+ . iter ( )
142+ . filter ( |m| m. body . is_some ( ) )
143+ . map ( |m| {
144+ let value: SNSWrapper = serde_json:: from_str ( m. body . as_ref ( ) . unwrap ( ) )
145+ . expect ( "Failed to deserialize SNS payload as JSON" ) ;
146+ value. to_vec ( )
147+ } )
148+ . flatten ( )
149+ . collect :: < Vec < String > > ( )
150+ } else {
151+ records
152+ . iter ( )
153+ . filter ( |m| m. body . is_some ( ) )
154+ . map ( |m| m. body . as_ref ( ) . unwrap ( ) . clone ( ) )
155+ . collect :: < Vec < String > > ( )
92156 }
93- unpacked
94157}
95158
96159#[ derive( Debug , Deserialize ) ]
@@ -99,10 +162,64 @@ struct SNSWrapper {
99162 records : Vec < serde_json:: Value > ,
100163}
101164
165+ impl SNSWrapper {
166+ /// to_vec() will handle converting all the deserialized JSON inside the wrapper back into
167+ /// strings for passing deeper into oxbow
168+ fn to_vec ( & self ) -> Vec < String > {
169+ self . records
170+ . iter ( )
171+ . map ( |v| serde_json:: to_string ( & v) . expect ( "Failed to reserialize SNS JSON" ) )
172+ . collect ( )
173+ }
174+ }
175+
176+ /// Return true if the function expecs an SNS envelope
177+ fn inside_sns ( ) -> bool {
178+ std:: env:: var ( "UNWRAP_SNS_ENVELOPE" ) . is_ok ( )
179+ }
180+
181+ /// These tests are for the BufferMore functionality
182+ #[ cfg( test) ]
183+ mod buffer_more_tests {
184+ use super :: * ;
185+
186+ use aws_sdk_sqs:: types:: Message ;
187+ use serial_test:: serial;
188+
189+ #[ serial]
190+ #[ test]
191+ fn test_extract_direct ( ) {
192+ let message = Message :: builder ( ) . body ( "hello" ) . build ( ) ;
193+
194+ let res = extract_json_from_sqs_direct ( vec ! [ message] ) ;
195+ assert_eq ! ( res, vec![ "hello" . to_string( ) ] ) ;
196+ }
197+
198+ #[ serial]
199+ #[ test]
200+ fn test_extract_direct_with_sns ( ) {
201+ let body = r#"{"Records":[{"eventVersion":"2.1"}]}"# ;
202+ let message = Message :: builder ( ) . body ( body) . build ( ) ;
203+
204+ unsafe {
205+ std:: env:: set_var ( "UNWRAP_SNS_ENVELOPE" , "true" ) ;
206+ }
207+
208+ let res = extract_json_from_sqs_direct ( vec ! [ message] ) ;
209+
210+ unsafe {
211+ std:: env:: remove_var ( "UNWRAP_SNS_ENVELOPE" ) ;
212+ }
213+ assert_eq ! ( res, vec![ r#"{"eventVersion":"2.1"}"# ] ) ;
214+ }
215+ }
216+
102217#[ cfg( test) ]
103218mod tests {
104219 use super :: * ;
220+ use serial_test:: serial;
105221
222+ #[ serial]
106223 #[ test]
107224 fn test_extract_data ( ) {
108225 let buf = r#"{
@@ -127,6 +244,7 @@ mod tests {
127244 assert_eq ! ( values, expected) ;
128245 }
129246
247+ #[ serial]
130248 #[ test]
131249 fn test_unwrap_sns ( ) {
132250 // This is an example of what a full message can look like
@@ -139,8 +257,16 @@ mod tests {
139257 let event: SqsEvent = SqsEvent {
140258 records : vec ! [ message] ,
141259 } ;
142- let values = unwrap_sns_payload ( & event. records ) ;
143- assert_eq ! ( values. len( ) , event. records. len( ) ) ;
144- assert_eq ! ( Some ( r#"{"eventVersion":"2.1"}"# ) , values[ 0 ] . body. as_deref( ) ) ;
260+
261+ unsafe {
262+ std:: env:: set_var ( "UNWRAP_SNS_ENVELOPE" , "true" ) ;
263+ }
264+
265+ let values = extract_json_from_records ( & event. records ) ;
266+
267+ unsafe {
268+ std:: env:: remove_var ( "UNWRAP_SNS_ENVELOPE" ) ;
269+ }
270+ assert_eq ! ( values, vec![ r#"{"eventVersion":"2.1"}"# ] ) ;
145271 }
146272}
0 commit comments