@@ -1,6 +1,5 @@
 import os
 from collections import deque
-from itertools import batched
 from pathlib import Path
 from typing import Callable
 from unittest import mock
@@ -17,11 +16,7 @@
 from event.json import json_load
 from moto import mock_aws
 from mypy_boto3_s3 import S3Client
-from sds.epr.bulk_create.bulk_load_fanout import (
-    FANOUT,
-    calculate_batch_size,
-    count_indexes,
-)
+from sds.epr.bulk_create.bulk_load_fanout import count_indexes
 
 from etl.sds.worker.bulk.tests.test_bulk_e2e import PATH_TO_STAGE_DATA
 
@@ -78,30 +73,15 @@ def decompress(obj: dict) -> dict:
     return obj
 
 
-@pytest.mark.parametrize(
-    ("n_batches", "sequence_length", "expected_batch_size", "expected_n_batches"),
-    ((4, 100, 25, 4), (3, 100, 34, 3), (16, 30, 2, 15)),
-)
-def test_calculate_batch_size_general(
-    n_batches: int,
-    sequence_length: int,
-    expected_batch_size: int,
-    expected_n_batches: int,
-):
-    n_batches = n_batches
-    sequence = list(range(sequence_length))
-    batch_size = calculate_batch_size(sequence, n_batches)
-    assert batch_size == expected_batch_size
-
-    batches = list(batched(sequence, batch_size))
-    assert len(batches) == expected_n_batches
-
-
 def test_load_worker_fanout(
-    put_object: Callable[[str], None], get_object: Callable[[str], bytes]
+    put_object: Callable[[str], None],
+    get_object: Callable[[str], bytes],
 ):
+    _EACH_FANOUT_BATCH_SIZE = 10
     from etl.sds.worker.bulk.load_bulk_fanout import load_bulk_fanout
 
+    load_bulk_fanout.EACH_FANOUT_BATCH_SIZE = _EACH_FANOUT_BATCH_SIZE
+
     # Initial state
     with open(PATH_TO_STAGE_DATA / "2.transform_output.json") as f:
         input_data: list[dict[str, dict]] = json_load(f)
@@ -114,31 +94,36 @@ def test_load_worker_fanout(
     # Execute the load worker
     responses = load_bulk_fanout.handler(event={}, context=None)
 
-    assert len(responses) == FANOUT
-    assert responses == [
+    *head_responses, tail_response = responses
+
+    assert len(head_responses) > 1
+
+    expected_head_responses = [
         {
             "stage_name": "load_bulk_fanout",
-            "processed_records": 10,
+            "processed_records": _EACH_FANOUT_BATCH_SIZE,
             "unprocessed_records": 0,
             "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{i}",
             "error_message": None,
         }
-        for i in range(0, FANOUT - 1)
-    ] + [
-        {
-            "stage_name": "load_bulk_fanout",
-            "processed_records": 7,
-            "unprocessed_records": 0,
-            "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{FANOUT - 1}",
-            "error_message": None,
-        },
+        for i in range(len(head_responses))
     ]
+    assert head_responses == expected_head_responses
+
+    tail_processed_records = tail_response.pop("processed_records")
+    assert tail_processed_records <= _EACH_FANOUT_BATCH_SIZE
+    assert tail_response == {
+        "stage_name": "load_bulk_fanout",
+        "unprocessed_records": 0,
+        "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{len(head_responses)}",
+        "error_message": None,
+    }
 
     # Final state
     final_processed_data = pkl_loads_lz4(get_object(key=WorkerKey.LOAD))
     assert final_processed_data == deque([])
     total_size = 0
-    for i in range(10):
+    for i in range(_EACH_FANOUT_BATCH_SIZE):
         final_unprocessed_data = pkl_loads_lz4(get_object(key=f"{WorkerKey.LOAD}.{i}"))
         assert isinstance(final_unprocessed_data, deque)
         total_size += len(final_unprocessed_data)
@@ -164,3 +149,8 @@ def test_load_worker_fanout(
         expected_total_size += count_indexes(obj)
 
     assert total_size == expected_total_size
+
+    total_processed_records_from_response = (
+        tail_processed_records + _EACH_FANOUT_BATCH_SIZE * len(head_responses)
+    )
+    assert total_size == total_processed_records_from_response
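
Note on the removed helper: the deleted test_calculate_batch_size_general pinned down a ceiling-division contract for calculate_batch_size. The sketch below is a hypothetical reconstruction of that behaviour (calculate_batch_size_sketch is not the project's implementation), replaying the deleted parametrized cases; it also shows why (16, 30) was expected to produce 15 batches rather than 16.

# A minimal sketch, assuming the removed helper chose a batch size so that
# `sequence` splits into at most `n_batches` batches via itertools.batched.
from itertools import batched  # Python 3.12+
from math import ceil


def calculate_batch_size_sketch(sequence: list, n_batches: int) -> int:
    # Hypothetical reconstruction of the deleted calculate_batch_size.
    return ceil(len(sequence) / n_batches)


# The deleted parametrized cases, replayed against the sketch:
for n_batches, sequence_length, expected_batch_size, expected_n_batches in (
    (4, 100, 25, 4),
    (3, 100, 34, 3),
    (16, 30, 2, 15),
):
    sequence = list(range(sequence_length))
    assert calculate_batch_size_sketch(sequence, n_batches) == expected_batch_size
    # The realised number of batches can be lower than requested (16 -> 15).
    assert len(list(batched(sequence, expected_batch_size))) == expected_n_batches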
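
Note on the new assertions: with a fixed per-fanout batch size, every head response reports a full batch and only the tail batch may be partial, so total_size must equal tail_processed_records + _EACH_FANOUT_BATCH_SIZE * len(head_responses). A minimal sketch of that conservation invariant, using an arbitrary 37-record payload and a stand-in batch size of 10 (both illustrative values, not taken from the test data):

# Split a payload into fixed-size batches, as the fanout is asserted to do.
records = list(range(37))  # arbitrary example payload
batch_size = 10            # stand-in for EACH_FANOUT_BATCH_SIZE

batches = [records[i : i + batch_size] for i in range(0, len(records), batch_size)]
*head, tail = batches

assert all(len(batch) == batch_size for batch in head)  # head batches are full
assert 0 < len(tail) <= batch_size                      # only the tail is partial
assert len(records) == len(tail) + batch_size * len(head)  # nothing lost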