Skip to content

Commit 1bc4841

Browse files
Adds test ensuring dictionary corruption does not occur anymore (#208)
* Add docs to `start_localhost_context` * Add test for nullification bug * Fix typo * Add test warning
1 parent 6f69516 commit 1bc4841

File tree

3 files changed

+108
-1
lines changed

3 files changed

+108
-1
lines changed

src/stage.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ use uuid::Uuid;
6868
/// The receiving ArrowFlightEndpoint will then execute Stage 2 and will repeat this process.
6969
///
7070
/// When Stage 4 is executed, it has no input tasks, so it is assumed that the plan included in that
71-
/// Stage can complete on its own; its likely holding a leaf node in the overall phyysical plan and
71+
/// Stage can complete on its own; it's likely holding a leaf node in the overall physical plan and
7272
/// producing data from a [`DataSourceExec`].
7373
#[derive(Debug, Clone)]
7474
pub struct Stage {

src/test_utils/localhost.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ use tokio::net::TcpListener;
1717
use tonic::transport::{Channel, Server};
1818
use url::Url;
1919

20+
/// Create workers and context on localhost with a fixed number of target partitions.
21+
///
22+
/// Creates `num_workers` listeners, all bound to a random OS-decided port on `127.0.0.1`, then
23+
/// attaches a channel resolver that is aware of these addresses to `session_builder` and uses it
24+
/// to spawn a flight service behind each listener.
25+
///
26+
/// Returns a session context aware of these workers, and a join set of all spawned worker tasks.
2027
pub async fn start_localhost_context<B>(
2128
num_workers: usize,
2229
session_builder: B,

tests/distributed_aggregation.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
#[cfg(all(feature = "integration", test))]
22
mod tests {
3+
use datafusion::arrow::array::{Int32Array, StringArray};
4+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
5+
use datafusion::arrow::record_batch::RecordBatch;
36
use datafusion::arrow::util::pretty::pretty_format_batches;
47
use datafusion::physical_plan::{displayable, execute_stream};
58
use datafusion_distributed::test_utils::localhost::start_localhost_context;
69
use datafusion_distributed::test_utils::parquet::register_parquet_tables;
10+
use datafusion_distributed::test_utils::session_context::register_temp_parquet_table;
711
use datafusion_distributed::{
812
DefaultSessionBuilder, DistributedConfig, apply_network_boundaries, assert_snapshot,
913
display_plan_ascii, distribute_plan,
1014
};
1115
use futures::TryStreamExt;
1216
use std::error::Error;
17+
use std::sync::Arc;
18+
use uuid::Uuid;
1319

1420
#[tokio::test]
1521
async fn distributed_aggregation() -> Result<(), Box<dyn Error>> {
@@ -149,4 +155,98 @@ mod tests {
149155

150156
Ok(())
151157
}
158+
159+
/// Test that multiple first_value() aggregations work correctly in distributed queries.
160+
// TODO: Once https://github.com/apache/datafusion/pull/18303 is merged, this test will lose
161+
// meaning, since the PR above will mask the underlying problem. Different queries or
162+
// a new approach must be used in this case.
163+
#[tokio::test]
164+
async fn test_multiple_first_value_aggregations() -> Result<(), Box<dyn Error>> {
165+
let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await;
166+
167+
let schema = Arc::new(Schema::new(vec![
168+
Field::new("group_id", DataType::Int32, false),
169+
Field::new("trace_id", DataType::Utf8, false),
170+
Field::new("value", DataType::Int32, false),
171+
]));
172+
173+
// Create 2 batches that will be stored as separate parquet files
174+
let batch1 = RecordBatch::try_new(
175+
schema.clone(),
176+
vec![
177+
Arc::new(Int32Array::from(vec![1, 2])),
178+
Arc::new(StringArray::from(vec!["trace1", "trace2"])),
179+
Arc::new(Int32Array::from(vec![100, 200])),
180+
],
181+
)?;
182+
183+
let batch2 = RecordBatch::try_new(
184+
schema.clone(),
185+
vec![
186+
Arc::new(Int32Array::from(vec![3, 4])),
187+
Arc::new(StringArray::from(vec!["trace3", "trace4"])),
188+
Arc::new(Int32Array::from(vec![300, 400])),
189+
],
190+
)?;
191+
192+
let file1 =
193+
register_temp_parquet_table("records_part1", schema.clone(), vec![batch1], &ctx)
194+
.await?;
195+
let file2 =
196+
register_temp_parquet_table("records_part2", schema.clone(), vec![batch2], &ctx)
197+
.await?;
198+
199+
// Create a partitioned table by registering multiple files
200+
let temp_dir = std::env::temp_dir();
201+
let table_dir = temp_dir.join(format!("partitioned_table_{}", Uuid::new_v4()));
202+
std::fs::create_dir(&table_dir)?;
203+
std::fs::copy(&file1, table_dir.join("part1.parquet"))?;
204+
std::fs::copy(&file2, table_dir.join("part2.parquet"))?;
205+
206+
// Register the directory as a partitioned table
207+
ctx.register_parquet(
208+
"records_partitioned",
209+
table_dir.to_str().unwrap(),
210+
datafusion::prelude::ParquetReadOptions::default(),
211+
)
212+
.await?;
213+
214+
let query = r#"SELECT group_id, first_value(trace_id) AS fv1, first_value(value) AS fv2
215+
FROM records_partitioned
216+
GROUP BY group_id
217+
ORDER BY group_id"#;
218+
219+
let df = ctx.sql(query).await?;
220+
let physical = df.create_physical_plan().await?;
221+
222+
let cfg = DistributedConfig::default().with_network_shuffle_tasks(2);
223+
let physical_distributed = apply_network_boundaries(physical, &cfg)?;
224+
let physical_distributed = distribute_plan(physical_distributed)?;
225+
226+
// Execute distributed query
227+
let batches_distributed = execute_stream(physical_distributed, ctx.task_ctx())?
228+
.try_collect::<Vec<_>>()
229+
.await?;
230+
231+
let actual_result = pretty_format_batches(&batches_distributed)?;
232+
let expected_result = "\
233+
+----------+--------+-----+
234+
| group_id | fv1 | fv2 |
235+
+----------+--------+-----+
236+
| 1 | trace1 | 100 |
237+
| 2 | trace2 | 200 |
238+
| 3 | trace3 | 300 |
239+
| 4 | trace4 | 400 |
240+
+----------+--------+-----+";
241+
242+
// Print them out; the error message from `assert_eq` is otherwise hard to read.
243+
println!("{}", expected_result);
244+
println!("{}", actual_result);
245+
246+
// Compare against result. The regression this is testing for would have NULL values in
247+
// the second and third column.
248+
assert_eq!(actual_result.to_string(), expected_result,);
249+
250+
Ok(())
251+
}
152252
}

0 commit comments

Comments
 (0)