@@ -23,8 +23,8 @@ use std::path::{Path, PathBuf};
2323use std:: ptr:: NonNull ;
2424
2525use arrow:: array:: ArrayData ;
26- use arrow:: datatypes:: SchemaRef ;
27- use arrow:: ipc:: reader:: FileReader ;
26+ use arrow:: datatypes:: { Schema , SchemaRef } ;
27+ use arrow:: ipc:: { reader:: StreamReader , writer :: StreamWriter } ;
2828use arrow:: record_batch:: RecordBatch ;
2929use log:: debug;
3030use tokio:: sync:: mpsc:: Sender ;
@@ -34,7 +34,6 @@ use datafusion_execution::disk_manager::RefCountedTempFile;
3434use datafusion_execution:: memory_pool:: human_readable_size;
3535use datafusion_execution:: SendableRecordBatchStream ;
3636
37- use crate :: common:: IPCWriter ;
3837use crate :: stream:: RecordBatchReceiverStream ;
3938
4039/// Read spilled batches from the disk
@@ -59,13 +58,13 @@ pub(crate) fn read_spill_as_stream(
5958///
6059/// Returns total number of the rows spilled to disk.
6160pub ( crate ) fn spill_record_batches (
62- batches : Vec < RecordBatch > ,
61+ batches : & [ RecordBatch ] ,
6362 path : PathBuf ,
6463 schema : SchemaRef ,
6564) -> Result < ( usize , usize ) > {
66- let mut writer = IPCWriter :: new ( path. as_ref ( ) , schema. as_ref ( ) ) ?;
65+ let mut writer = IPCStreamWriter :: new ( path. as_ref ( ) , schema. as_ref ( ) ) ?;
6766 for batch in batches {
68- writer. write ( & batch) ?;
67+ writer. write ( batch) ?;
6968 }
7069 writer. finish ( ) ?;
7170 debug ! (
@@ -79,7 +78,7 @@ pub(crate) fn spill_record_batches(
7978
8079fn read_spill ( sender : Sender < Result < RecordBatch > > , path : & Path ) -> Result < ( ) > {
8180 let file = BufReader :: new ( File :: open ( path) ?) ;
82- let reader = FileReader :: try_new ( file, None ) ?;
81+ let reader = StreamReader :: try_new ( file, None ) ?;
8382 for batch in reader {
8483 sender
8584 . blocking_send ( batch. map_err ( Into :: into) )
@@ -98,7 +97,7 @@ pub fn spill_record_batch_by_size(
9897) -> Result < ( ) > {
9998 let mut offset = 0 ;
10099 let total_rows = batch. num_rows ( ) ;
101- let mut writer = IPCWriter :: new ( & path, schema. as_ref ( ) ) ?;
100+ let mut writer = IPCStreamWriter :: new ( & path, schema. as_ref ( ) ) ?;
102101
103102 while offset < total_rows {
104103 let length = std:: cmp:: min ( total_rows - offset, batch_size_rows) ;
@@ -130,7 +129,7 @@ pub fn spill_record_batch_by_size(
130129/// {xxxxxxxxxxxxxxxxxxx} <--- buffer
131130/// ^ ^ ^ ^
132131/// | | | |
133- /// col1->{ } | |
132+ /// col1->{ } | |
134133/// col2--------->{ }
135134///
136135/// In the above case, `get_record_batch_memory_size` will return the size of
@@ -179,17 +178,64 @@ fn count_array_data_memory_size(
179178 }
180179}
181180
181+ /// Write in Arrow IPC Stream format to a file.
182+ ///
183+ /// Stream format is used for spill because it supports dictionary replacement, and the random
184+ /// access of IPC File format is not needed (IPC File format doesn't support dictionary replacement).
185+ struct IPCStreamWriter {
186+ /// Inner writer
187+ pub writer : StreamWriter < File > ,
188+ /// Batches written
189+ pub num_batches : usize ,
190+ /// Rows written
191+ pub num_rows : usize ,
192+ /// Bytes written
193+ pub num_bytes : usize ,
194+ }
195+
196+ impl IPCStreamWriter {
197+ /// Create new writer
198+ pub fn new ( path : & Path , schema : & Schema ) -> Result < Self > {
199+ let file = File :: create ( path) . map_err ( |e| {
200+ exec_datafusion_err ! ( "Failed to create partition file at {path:?}: {e:?}" )
201+ } ) ?;
202+ Ok ( Self {
203+ num_batches : 0 ,
204+ num_rows : 0 ,
205+ num_bytes : 0 ,
206+ writer : StreamWriter :: try_new ( file, schema) ?,
207+ } )
208+ }
209+
210+ /// Write one single batch
211+ pub fn write ( & mut self , batch : & RecordBatch ) -> Result < ( ) > {
212+ self . writer . write ( batch) ?;
213+ self . num_batches += 1 ;
214+ self . num_rows += batch. num_rows ( ) ;
215+ let num_bytes: usize = batch. get_array_memory_size ( ) ;
216+ self . num_bytes += num_bytes;
217+ Ok ( ( ) )
218+ }
219+
220+ /// Finish the writer
221+ pub fn finish ( & mut self ) -> Result < ( ) > {
222+ self . writer . finish ( ) . map_err ( Into :: into)
223+ }
224+ }
225+
182226#[ cfg( test) ]
183227mod tests {
184228 use super :: * ;
185229 use crate :: spill:: { spill_record_batch_by_size, spill_record_batches} ;
186230 use crate :: test:: build_table_i32;
187231 use arrow:: array:: { Float64Array , Int32Array , ListArray } ;
232+ use arrow:: compute:: cast;
188233 use arrow:: datatypes:: { DataType , Field , Int32Type , Schema } ;
189234 use arrow:: record_batch:: RecordBatch ;
190235 use datafusion_common:: Result ;
191236 use datafusion_execution:: disk_manager:: DiskManagerConfig ;
192237 use datafusion_execution:: DiskManager ;
238+ use itertools:: Itertools ;
193239 use std:: fs:: File ;
194240 use std:: io:: BufReader ;
195241 use std:: sync:: Arc ;
@@ -214,18 +260,85 @@ mod tests {
214260 let schema = batch1. schema ( ) ;
215261 let num_rows = batch1. num_rows ( ) + batch2. num_rows ( ) ;
216262 let ( spilled_rows, _) = spill_record_batches (
217- vec ! [ batch1, batch2] ,
263+ & [ batch1, batch2] ,
218264 spill_file. path ( ) . into ( ) ,
219265 Arc :: clone ( & schema) ,
220266 ) ?;
221267 assert_eq ! ( spilled_rows, num_rows) ;
222268
223269 let file = BufReader :: new ( File :: open ( spill_file. path ( ) ) ?) ;
224- let reader = FileReader :: try_new ( file, None ) ?;
270+ let reader = StreamReader :: try_new ( file, None ) ?;
225271
226- assert_eq ! ( reader. num_batches( ) , 2 ) ;
227272 assert_eq ! ( reader. schema( ) , schema) ;
228273
274+ let batches = reader. collect_vec ( ) ;
275+ assert ! ( batches. len( ) == 2 ) ;
276+
277+ Ok ( ( ) )
278+ }
279+
#[test]
fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> {
    // Regression test: spilling dictionary-encoded batches whose dictionaries
    // differ between batches.
    // See https://github.com/apache/datafusion/issues/4658

    let batch1 = build_table_i32(
        ("a2", &vec![0, 1, 2]),
        ("b2", &vec![3, 4, 5]),
        ("c2", &vec![4, 5, 6]),
    );

    let batch2 = build_table_i32(
        ("a2", &vec![10, 11, 12]),
        ("b2", &vec![13, 14, 15]),
        ("c2", &vec![14, 15, 16]),
    );

    // Dictionary encode the arrays
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
    let dict_schema = Arc::new(Schema::new(vec![
        Field::new("a2", dict_type.clone(), true),
        Field::new("b2", dict_type.clone(), true),
        Field::new("c2", dict_type.clone(), true),
    ]));

    let batch1 = RecordBatch::try_new(
        Arc::clone(&dict_schema),
        batch1
            .columns()
            .iter()
            .map(|array| cast(array, &dict_type))
            .collect::<Result<_, _>>()?,
    )?;

    let batch2 = RecordBatch::try_new(
        Arc::clone(&dict_schema),
        batch2
            .columns()
            .iter()
            .map(|array| cast(array, &dict_type))
            .collect::<Result<_, _>>()?,
    )?;

    let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;

    let spill_file = disk_manager.create_tmp_file("Test Spill")?;
    let num_rows = batch1.num_rows() + batch2.num_rows();
    let (spilled_rows, _) = spill_record_batches(
        &[batch1, batch2],
        spill_file.path().into(),
        Arc::clone(&dict_schema),
    )?;
    assert_eq!(spilled_rows, num_rows);

    let file = BufReader::new(File::open(spill_file.path())?);
    let reader = StreamReader::try_new(file, None)?;

    assert_eq!(reader.schema(), dict_schema);

    // The reader yields `Result<RecordBatch, _>` items. Collect into a `Result`
    // so a batch that fails to decode fails the test instead of being counted.
    let batches = reader.collect::<Result<Vec<_>, _>>()?;
    assert_eq!(batches.len(), 2);

    // Every spilled row must come back from the file.
    let read_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
    assert_eq!(read_rows, num_rows);

    Ok(())
}
231344
@@ -249,11 +362,13 @@ mod tests {
249362 ) ?;
250363
251364 let file = BufReader :: new ( File :: open ( spill_file. path ( ) ) ?) ;
252- let reader = FileReader :: try_new ( file, None ) ?;
365+ let reader = StreamReader :: try_new ( file, None ) ?;
253366
254- assert_eq ! ( reader. num_batches( ) , 4 ) ;
255367 assert_eq ! ( reader. schema( ) , schema) ;
256368
369+ let batches = reader. collect_vec ( ) ;
370+ assert ! ( batches. len( ) == 4 ) ;
371+
257372 Ok ( ( ) )
258373 }
259374
0 commit comments