Skip to content

Commit 6457ff2

Browse files
committed
Refactor record-batch writing through a shared sink harness
Made-with: Cursor
1 parent 82d25ea commit 6457ff2

File tree

7 files changed

+144
-40
lines changed

7 files changed

+144
-40
lines changed

src/pipeline.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! The `pipeline` module is the core of the datu crate.
22
33
pub mod avro;
4+
pub mod batch_write;
45
pub mod csv;
56
pub mod dataframe;
67
pub mod datasource;

src/pipeline/avro.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use crate::pipeline::RecordBatchReaderSource;
1313
use crate::pipeline::Source;
1414
use crate::pipeline::Step;
1515
use crate::pipeline::WriteArgs;
16+
use crate::pipeline::batch_write::BatchWriteSink;
17+
use crate::pipeline::batch_write::write_record_batches_with_sink;
1618

1719
/// Pipeline step that reads an Avro file and produces a record batch reader.
1820
pub struct ReadAvroStep {
@@ -144,15 +146,30 @@ pub struct WriteAvroResult {}
144146

145147
/// Write record batches from a reader to an Avro file.
146148
pub fn write_record_batches(path: &str, reader: &mut dyn RecordBatchReader) -> Result<()> {
147-
let file = std::fs::File::create(path).map_err(Error::IoError)?;
148-
let schema = reader.schema();
149-
let mut writer = AvroWriter::new(file, (*schema).clone()).map_err(Error::ArrowError)?;
150-
for batch in reader {
151-
let batch = batch.map_err(Error::ArrowError)?;
152-
writer.write(&batch).map_err(Error::ArrowError)?;
149+
write_record_batches_with_sink(path, reader, AvroSink::new)
150+
}
151+
152+
struct AvroSink {
153+
writer: AvroWriter<std::fs::File>,
154+
}
155+
156+
impl AvroSink {
157+
fn new(path: &str, schema: arrow::datatypes::SchemaRef) -> Result<Self> {
158+
let file = std::fs::File::create(path).map_err(Error::IoError)?;
159+
let writer = AvroWriter::new(file, (*schema).clone()).map_err(Error::ArrowError)?;
160+
Ok(Self { writer })
161+
}
162+
}
163+
164+
impl BatchWriteSink for AvroSink {
165+
fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
166+
self.writer.write(batch).map_err(Error::ArrowError)
167+
}
168+
169+
fn finish(mut self) -> Result<()> {
170+
self.writer.finish().map_err(Error::ArrowError)?;
171+
Ok(())
153172
}
154-
writer.finish().map_err(Error::ArrowError)?;
155-
Ok(())
156173
}
157174

158175
#[async_trait(?Send)]

src/pipeline/batch_write.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use arrow::array::RecordBatchReader;
2+
use arrow::datatypes::SchemaRef;
3+
use arrow::record_batch::RecordBatch;
4+
5+
use crate::Error;
6+
use crate::Result;
7+
8+
/// Per-format sink adapter for writing record batches.
9+
pub trait BatchWriteSink {
10+
fn write_batch(&mut self, batch: &RecordBatch) -> Result<()>;
11+
fn finish(self) -> Result<()>;
12+
}
13+
14+
/// Shared harness for batch-oriented file writers.
15+
pub fn write_record_batches_with_sink<S, BuildSink>(
16+
path: &str,
17+
reader: &mut dyn RecordBatchReader,
18+
build_sink: BuildSink,
19+
) -> Result<()>
20+
where
21+
S: BatchWriteSink,
22+
BuildSink: FnOnce(&str, SchemaRef) -> Result<S>,
23+
{
24+
let schema = reader.schema();
25+
let mut sink = build_sink(path, schema)?;
26+
27+
for batch in reader {
28+
let batch = batch.map_err(Error::ArrowError)?;
29+
sink.write_batch(&batch)?;
30+
}
31+
32+
sink.finish()
33+
}

src/pipeline/csv.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use crate::pipeline::Source;
1010
use crate::pipeline::Step;
1111
use crate::pipeline::VecRecordBatchReader;
1212
use crate::pipeline::WriteArgs;
13+
use crate::pipeline::batch_write::BatchWriteSink;
14+
use crate::pipeline::batch_write::write_record_batches_with_sink;
1315

1416
/// Pipeline step that reads a CSV file and produces a record batch reader.
1517
/// Uses DataFusion for schema inference and type detection.
@@ -47,21 +49,43 @@ pub struct WriteCsvStep {
4749
/// Result of successfully writing a CSV file.
4850
pub struct WriteCsvResult {}
4951

52+
/// Write record batches from a reader to a CSV file.
53+
pub fn write_record_batches(path: &str, reader: &mut dyn RecordBatchReader) -> Result<()> {
54+
write_record_batches_with_sink(path, reader, CsvSink::new)
55+
}
56+
57+
struct CsvSink {
58+
writer: arrow::csv::Writer<std::fs::File>,
59+
}
60+
61+
impl CsvSink {
62+
fn new(path: &str, _schema: arrow::datatypes::SchemaRef) -> Result<Self> {
63+
let file = std::fs::File::create(path).map_err(Error::IoError)?;
64+
Ok(Self {
65+
writer: arrow::csv::Writer::new(file),
66+
})
67+
}
68+
}
69+
70+
impl BatchWriteSink for CsvSink {
71+
fn write_batch(&mut self, batch: &arrow::record_batch::RecordBatch) -> Result<()> {
72+
self.writer.write(batch).map_err(Error::ArrowError)
73+
}
74+
75+
fn finish(self) -> Result<()> {
76+
Ok(())
77+
}
78+
}
79+
5080
#[async_trait(?Send)]
5181
impl Step for WriteCsvStep {
5282
type Input = ();
5383
type Output = WriteCsvResult;
5484

5585
async fn execute(self, _input: Self::Input) -> Result<Self::Output> {
56-
let path = self.args.path.as_str();
57-
let file = std::fs::File::create(path).map_err(Error::IoError)?;
58-
let mut writer = arrow::csv::Writer::new(file);
5986
let mut source = self.source;
60-
let reader = source.get()?;
61-
for batch in reader {
62-
let batch = batch.map_err(Error::ArrowError)?;
63-
writer.write(&batch).map_err(Error::ArrowError)?;
64-
}
87+
let mut reader = source.get()?;
88+
write_record_batches(self.args.path.as_str(), &mut *reader)?;
6589
Ok(WriteCsvResult {})
6690
}
6791
}

src/pipeline/dataframe.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,7 @@ impl Step for DataFrameWriter {
8686
parquet::write_record_batches(output_path, &mut reader)?;
8787
}
8888
FileType::Csv => {
89-
let file = std::fs::File::create(output_path).map_err(Error::IoError)?;
90-
let mut writer = arrow::csv::Writer::new(file);
91-
for batch in &mut reader {
92-
let batch = batch.map_err(Error::ArrowError)?;
93-
writer.write(&batch).map_err(Error::ArrowError)?;
94-
}
89+
crate::pipeline::csv::write_record_batches(output_path, &mut reader)?;
9590
}
9691
FileType::Json => {
9792
let file = std::fs::File::create(output_path).map_err(Error::IoError)?;

src/pipeline/orc.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use crate::pipeline::RecordBatchReaderSource;
1111
use crate::pipeline::Source;
1212
use crate::pipeline::Step;
1313
use crate::pipeline::WriteArgs;
14+
use crate::pipeline::batch_write::BatchWriteSink;
15+
use crate::pipeline::batch_write::write_record_batches_with_sink;
1416

1517
/// Pipeline step that reads an ORC file and produces a record batch reader.
1618
pub struct ReadOrcStep {
@@ -52,17 +54,32 @@ pub struct WriteOrcResult {}
5254

5355
/// Write record batches from a reader to an ORC file.
5456
pub fn write_record_batches(path: &str, reader: &mut dyn RecordBatchReader) -> Result<()> {
55-
let file = std::fs::File::create(path).map_err(Error::IoError)?;
56-
let schema = reader.schema();
57-
let mut writer = ArrowWriterBuilder::new(file, schema)
58-
.try_build()
59-
.map_err(Error::OrcError)?;
60-
for batch in reader {
61-
let batch = batch.map_err(Error::ArrowError)?;
62-
writer.write(&batch).map_err(Error::OrcError)?;
57+
write_record_batches_with_sink(path, reader, OrcSink::new)
58+
}
59+
60+
struct OrcSink {
61+
writer: orc_rust::arrow_writer::ArrowWriter<std::fs::File>,
62+
}
63+
64+
impl OrcSink {
65+
fn new(path: &str, schema: arrow::datatypes::SchemaRef) -> Result<Self> {
66+
let file = std::fs::File::create(path).map_err(Error::IoError)?;
67+
let writer = ArrowWriterBuilder::new(file, schema)
68+
.try_build()
69+
.map_err(Error::OrcError)?;
70+
Ok(Self { writer })
71+
}
72+
}
73+
74+
impl BatchWriteSink for OrcSink {
75+
fn write_batch(&mut self, batch: &arrow::record_batch::RecordBatch) -> Result<()> {
76+
self.writer.write(batch).map_err(Error::OrcError)
77+
}
78+
79+
fn finish(self) -> Result<()> {
80+
self.writer.close().map_err(Error::OrcError)?;
81+
Ok(())
6382
}
64-
writer.close().map_err(Error::OrcError)?;
65-
Ok(())
6683
}
6784

6885
#[async_trait(?Send)]

src/pipeline/parquet.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use crate::pipeline::RecordBatchReaderSource;
1111
use crate::pipeline::Source;
1212
use crate::pipeline::Step;
1313
use crate::pipeline::WriteArgs;
14+
use crate::pipeline::batch_write::BatchWriteSink;
15+
use crate::pipeline::batch_write::write_record_batches_with_sink;
1416

1517
/// Pipeline step that reads a Parquet file and produces a record batch reader.
1618
pub struct ReadParquetStep {
@@ -62,15 +64,30 @@ impl Step for WriteParquetStep {
6264

6365
/// Write record batches from a reader to a Parquet file.
6466
pub fn write_record_batches(path: &str, reader: &mut dyn RecordBatchReader) -> Result<()> {
65-
let file = std::fs::File::create(path).map_err(Error::IoError)?;
66-
let schema = reader.schema();
67-
let mut writer = ArrowWriter::try_new(file, schema, None).map_err(Error::ParquetError)?;
68-
for batch in reader {
69-
let batch = batch.map_err(Error::ArrowError)?;
70-
writer.write(&batch).map_err(Error::ParquetError)?;
67+
write_record_batches_with_sink(path, reader, ParquetSink::new)
68+
}
69+
70+
struct ParquetSink {
71+
writer: ArrowWriter<std::fs::File>,
72+
}
73+
74+
impl ParquetSink {
75+
fn new(path: &str, schema: arrow::datatypes::SchemaRef) -> Result<Self> {
76+
let file = std::fs::File::create(path).map_err(Error::IoError)?;
77+
let writer = ArrowWriter::try_new(file, schema, None).map_err(Error::ParquetError)?;
78+
Ok(Self { writer })
79+
}
80+
}
81+
82+
impl BatchWriteSink for ParquetSink {
83+
fn write_batch(&mut self, batch: &arrow::record_batch::RecordBatch) -> Result<()> {
84+
self.writer.write(batch).map_err(Error::ParquetError)
85+
}
86+
87+
fn finish(self) -> Result<()> {
88+
self.writer.close().map_err(Error::ParquetError)?;
89+
Ok(())
7190
}
72-
writer.close().map_err(Error::ParquetError)?;
73-
Ok(())
7491
}
7592

7693
#[cfg(test)]

0 commit comments

Comments
 (0)