Skip to content

Commit e86190e

Browse files
test_utils: add utility to create test parquet tables from RecordBatch
This change adds a `session_context` module to `test_utils` that lets you (a) create temporary Parquet files from `RecordBatch` data and (b) register that data in the `SessionContext` under a table name. This is handy for unit tests, where you typically want to execute a distributed SQL query against some small hard-coded `RecordBatch` data.
1 parent 22c3802 commit e86190e

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

src/test_utils/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ pub mod insta;
33
pub mod localhost;
44
pub mod mock_exec;
55
pub mod parquet;
6+
pub mod session_context;
67
pub mod tpch;

src/test_utils/session_context.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use arrow::record_batch::RecordBatch;
2+
use datafusion::arrow::datatypes::SchemaRef;
3+
use datafusion::error::Result;
4+
use datafusion::execution::context::SessionContext;
5+
use datafusion::prelude::ParquetReadOptions;
6+
use parquet::arrow::ArrowWriter;
7+
use std::path::PathBuf;
8+
use uuid::Uuid;
9+
10+
/// Creates a temporary Parquet file from RecordBatches and registers it with the SessionContext
11+
/// under the provided table name. Returns the file name.
12+
///
13+
/// TODO: consider expanding this to support partitioned data
14+
pub async fn register_temp_parquet_table(
15+
table_name: &str,
16+
schema: SchemaRef,
17+
batches: Vec<RecordBatch>,
18+
ctx: &SessionContext,
19+
) -> Result<PathBuf> {
20+
if batches.is_empty() {
21+
return Err(datafusion::error::DataFusionError::Execution(
22+
"Cannot create Parquet file from empty batch list".to_string(),
23+
));
24+
}
25+
for batch in &batches {
26+
if batch.schema() != schema {
27+
return Err(datafusion::error::DataFusionError::Execution(
28+
"All batches must have the same schema".to_string(),
29+
));
30+
}
31+
}
32+
33+
let temp_dir = std::env::temp_dir();
34+
let file_id = Uuid::new_v4();
35+
let temp_file_path = temp_dir.join(format!("{}_{}.parquet", table_name, file_id,));
36+
37+
let file = std::fs::File::create(&temp_file_path)?;
38+
let schema = batches[0].schema();
39+
let mut writer = ArrowWriter::try_new(file, schema, None)?;
40+
41+
for batch in batches {
42+
writer.write(&batch)?;
43+
}
44+
writer.close()?;
45+
46+
ctx.register_parquet(
47+
table_name,
48+
temp_file_path.to_string_lossy().as_ref(),
49+
ParquetReadOptions::default(),
50+
)
51+
.await?;
52+
53+
Ok(temp_file_path)
54+
}
55+
56+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use tokio::fs::remove_file;

    use std::sync::Arc;

    /// End-to-end check: write a small batch to a temp Parquet file, register
    /// it as a table, and confirm it can be queried back through SQL.
    #[tokio::test]
    async fn test_register_temp_parquet_table() {
        let ctx = SessionContext::new();

        // Two-column schema: a non-null Int32 id and a non-null Utf8 name.
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids: Int32Array = vec![1, 2, 3].into();
        let names: StringArray = vec!["a", "b", "c"].into();
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(names)])
            .expect("batch construction should succeed");

        let temp_file = register_temp_parquet_table("test_table", schema.clone(), vec![batch], &ctx)
            .await
            .expect("registration should succeed");

        // A COUNT(*) over the registered table yields a single one-row batch.
        let results = ctx
            .sql("SELECT COUNT(*) FROM test_table")
            .await
            .expect("query should plan")
            .collect()
            .await
            .expect("query should execute");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].num_rows(), 1);

        // Best-effort cleanup of the temp file; failure here is not a test failure.
        let _ = remove_file(temp_file).await;
    }
}

0 commit comments

Comments
 (0)