@@ -149,3 +149,192 @@ impl<B: IcebergWriterBuilder> RollingFileWriter for RollingDataFileWriter<B> {
         self.written_size + input_size > self.target_size
     }
 }
+
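+// The tests below exercise the rolling writer in two modes: a target size large
+// enough that all rows stay in a single data file, and a tiny target size that
+// forces the writer to roll over into multiple files.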
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::sync::Arc;
+
+    use arrow_array::{Int32Array, StringArray};
+    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
+    use parquet::file::properties::WriterProperties;
+    use tempfile::TempDir;
+
+    use super::*;
+    use crate::io::FileIOBuilder;
+    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
+    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
+    use crate::writer::file_writer::ParquetWriterBuilder;
+    use crate::writer::file_writer::location_generator::DefaultFileNameGenerator;
+    use crate::writer::file_writer::location_generator::test::MockLocationGenerator;
+    use crate::writer::tests::check_parquet_data_file;
+    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};
+
+    #[tokio::test]
+    async fn test_rolling_writer_basic() -> Result<()> {
+        let temp_dir = TempDir::new().unwrap();
+        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
+        let location_gen =
+            MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
+        let file_name_gen =
+            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
+
+        // Create schema
+        let schema = Schema::builder()
+            .with_schema_id(1)
+            .with_fields(vec![
+                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()?;
+
+        // Create writer builders
+        let parquet_writer_builder = ParquetWriterBuilder::new(
+            WriterProperties::builder().build(),
+            Arc::new(schema),
+            file_io.clone(),
+            location_gen,
+            file_name_gen,
+        );
+        let data_file_writer_builder = DataFileWriterBuilder::new(parquet_writer_builder, None, 0);
+
+        // Set a large target size so no rolling occurs
+        let rolling_writer_builder = RollingDataFileWriterBuilder::new(
+            data_file_writer_builder,
+            1024 * 1024, // 1MB, large enough to not trigger rolling
+        );
+
+        // Create writer
+        let mut writer = rolling_writer_builder.build().await?;
+
+        // Create test data
+        let arrow_schema = ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
+                PARQUET_FIELD_ID_META_KEY.to_string(),
+                1.to_string(),
+            )])),
+            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
+                PARQUET_FIELD_ID_META_KEY.to_string(),
+                2.to_string(),
+            )])),
+        ]);
+
+        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])),
+            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
+        ])?;
+
+        // Write data
+        writer.write(batch.clone()).await?;
+
+        // Close writer and get data files
+        let data_files = writer.close().await?;
+
+        // Verify only one file was created
+        assert_eq!(
+            data_files.len(),
+            1,
+            "Expected only one data file to be created"
+        );
+
+        // Verify file content
+        check_parquet_data_file(&file_io, &data_files[0], &batch).await;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_rolling_writer_with_rolling() -> Result<()> {
+        let temp_dir = TempDir::new().unwrap();
+        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
+        let location_gen =
+            MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
+        let file_name_gen =
+            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
+
+        // Create schema
+        let schema = Schema::builder()
+            .with_schema_id(1)
+            .with_fields(vec![
+                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()?;
+
+        // Create writer builders
+        let parquet_writer_builder = ParquetWriterBuilder::new(
+            WriterProperties::builder().build(),
+            Arc::new(schema),
+            file_io.clone(),
+            location_gen,
+            file_name_gen,
+        );
+        let data_file_writer_builder = DataFileWriterBuilder::new(parquet_writer_builder, None, 0);
+
+        // Set a very small target size to trigger rolling
+        let rolling_writer_builder = RollingDataFileWriterBuilder::new(
+            data_file_writer_builder,
+            100, // Very small target size to ensure rolling
+        );
+
+        // Create writer
+        let mut writer = rolling_writer_builder.build().await?;
+
+        // Create test data
+        let arrow_schema = ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
+                PARQUET_FIELD_ID_META_KEY.to_string(),
+                1.to_string(),
+            )])),
+            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
+                PARQUET_FIELD_ID_META_KEY.to_string(),
+                2.to_string(),
+            )])),
+        ]);
+
+        // Create multiple batches to trigger rolling
+        let batch1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])),
+            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
+        ])?;
+
+        let batch2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
+            Arc::new(Int32Array::from(vec![4, 5, 6])),
+            Arc::new(StringArray::from(vec!["Dave", "Eve", "Frank"])),
+        ])?;
+
+        let batch3 = RecordBatch::try_new(Arc::new(arrow_schema), vec![
+            Arc::new(Int32Array::from(vec![7, 8, 9])),
+            Arc::new(StringArray::from(vec!["Grace", "Heidi", "Ivan"])),
+        ])?;
+
+        // Write data
+        writer.write(batch1.clone()).await?;
+        writer.write(batch2.clone()).await?;
+        writer.write(batch3.clone()).await?;
+
+        // Close writer and get data files
+        let data_files = writer.close().await?;
+
+        // Verify multiple files were created (at least 2)
+        assert!(
+            data_files.len() > 1,
+            "Expected multiple data files to be created, got {}",
+            data_files.len()
+        );
+
+        // Verify total record count across all files
+        let total_records: u64 = data_files.iter().map(|file| file.record_count).sum();
+        assert_eq!(
+            total_records, 9,
+            "Expected 9 total records across all files"
+        );
+
+        // Verifying exactly which rows landed in which file would require reading
+        // each file back; here we only assert the total record count across files.
+
+        Ok(())
+    }
+}