src/stage.rs (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ use uuid::Uuid;
/// The receiving ArrowFlightEndpoint will then execute Stage 2 and will repeat this process.
///
/// When Stage 4 is executed, it has no input tasks, so it is assumed that the plan included in that
-/// Stage can complete on its own; its likely holding a leaf node in the overall phyysical plan and
+/// Stage can complete on its own; it's likely holding a leaf node in the overall physical plan and
/// producing data from a [`DataSourceExec`].
#[derive(Debug, Clone)]
pub struct Stage {
src/test_utils/localhost.rs (7 changes: 7 additions & 0 deletions)
@@ -18,6 +18,13 @@ use tokio::net::TcpListener;
use tonic::transport::{Channel, Server};
use url::Url;

+/// Create workers and context on localhost with a fixed number of target partitions.
+///
+/// Creates `num_workers` listeners, each bound to a random OS-assigned port on `127.0.0.1`, then
+/// attaches a channel resolver aware of these addresses to `session_builder` and uses it
+/// to spawn a flight service behind each listener.
+///
+/// Returns a session context aware of these workers, and a join set of all spawned worker tasks.
pub async fn start_localhost_context<B>(
    num_workers: usize,
    session_builder: B,
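
As a rough usage sketch of this helper (mirroring the call pattern in the test added below; the test name and query are illustrative, and `DefaultSessionBuilder` is assumed to satisfy the `B` bound):

use datafusion_distributed::DefaultSessionBuilder;
use datafusion_distributed::test_utils::localhost::start_localhost_context;

#[tokio::test]
async fn example_three_workers() -> Result<(), Box<dyn std::error::Error>> {
    // Spawn three workers on OS-assigned localhost ports; keep the returned
    // join set bound to a variable so the worker tasks stay alive.
    let (ctx, _workers) = start_localhost_context(3, DefaultSessionBuilder).await;
    // The returned context resolves distributed stages against those workers.
    let df = ctx.sql("SELECT 1 AS one").await?;
    df.show().await?;
    Ok(())
}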
tests/distributed_aggregation.rs (100 changes: 100 additions & 0 deletions)
@@ -1,15 +1,21 @@
#[cfg(all(feature = "integration", test))]
mod tests {
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::arrow::util::pretty::pretty_format_batches;
    use datafusion::physical_plan::{displayable, execute_stream};
    use datafusion_distributed::test_utils::localhost::start_localhost_context;
    use datafusion_distributed::test_utils::parquet::register_parquet_tables;
    use datafusion_distributed::test_utils::session_context::register_temp_parquet_table;
    use datafusion_distributed::{
        DefaultSessionBuilder, DistributedConfig, apply_network_boundaries, assert_snapshot,
        display_plan_ascii, distribute_plan,
    };
    use futures::TryStreamExt;
    use std::error::Error;
    use std::sync::Arc;
    use uuid::Uuid;

    #[tokio::test]
    async fn distributed_aggregation() -> Result<(), Box<dyn Error>> {
@@ -149,4 +155,98 @@ mod tests {

        Ok(())
    }

    /// Test that multiple first_value() aggregations work correctly in distributed queries.
    // TODO: Once https://github.com/apache/datafusion/pull/18303 is merged, this test will lose
    // meaning, since the PR above will mask the underlying problem. Different queries or
    // a new approach must be used in that case.
Comment on lines +159 to +162
Contributor: Should we just check for duplicate column names directly then? We could make a MemoryExec or something.

Contributor (author): I can look into that as a follow-up? I want to get through the remaining issues with distributed activation first; this can be a side-quest.

Contributor: Sure, yes. We'll see what the maintainers here think.

Collaborator: I think this is fine. It's still a good test. If you could file an issue and put the number in the TODO, that would help us track it :)

    #[tokio::test]
    async fn test_multiple_first_value_aggregations() -> Result<(), Box<dyn Error>> {
        let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await;

        let schema = Arc::new(Schema::new(vec![
            Field::new("group_id", DataType::Int32, false),
            Field::new("trace_id", DataType::Utf8, false),
            Field::new("value", DataType::Int32, false),
        ]));

        // Create 2 batches that will be stored as separate parquet files
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["trace1", "trace2"])),
                Arc::new(Int32Array::from(vec![100, 200])),
            ],
        )?;

        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![3, 4])),
                Arc::new(StringArray::from(vec!["trace3", "trace4"])),
                Arc::new(Int32Array::from(vec![300, 400])),
            ],
        )?;

        let file1 =
            register_temp_parquet_table("records_part1", schema.clone(), vec![batch1], &ctx)
                .await?;
        let file2 =
            register_temp_parquet_table("records_part2", schema.clone(), vec![batch2], &ctx)
                .await?;

        // Create a partitioned table directory containing both parquet files
        let temp_dir = std::env::temp_dir();
        let table_dir = temp_dir.join(format!("partitioned_table_{}", Uuid::new_v4()));
        std::fs::create_dir(&table_dir)?;
        std::fs::copy(&file1, table_dir.join("part1.parquet"))?;
        std::fs::copy(&file2, table_dir.join("part2.parquet"))?;

        // Register the directory as a partitioned table
        ctx.register_parquet(
            "records_partitioned",
            table_dir.to_str().unwrap(),
            datafusion::prelude::ParquetReadOptions::default(),
        )
        .await?;

        let query = r#"SELECT group_id, first_value(trace_id) AS fv1, first_value(value) AS fv2
            FROM records_partitioned
            GROUP BY group_id
            ORDER BY group_id"#;

        let df = ctx.sql(query).await?;
        let physical = df.create_physical_plan().await?;

        // Apply network boundaries (two network-shuffle tasks) and split the
        // plan into distributed stages before executing it.
        let cfg = DistributedConfig::default().with_network_shuffle_tasks(2);
        let physical_distributed = apply_network_boundaries(physical, &cfg)?;
        let physical_distributed = distribute_plan(physical_distributed)?;

        // Execute distributed query
        let batches_distributed = execute_stream(physical_distributed, ctx.task_ctx())?
            .try_collect::<Vec<_>>()
            .await?;

        let actual_result = pretty_format_batches(&batches_distributed)?;
        let expected_result = "\
+----------+--------+-----+
| group_id | fv1    | fv2 |
+----------+--------+-----+
| 1        | trace1 | 100 |
| 2        | trace2 | 200 |
| 3        | trace3 | 300 |
| 4        | trace4 | 400 |
+----------+--------+-----+";

        // Print them out; the error message from `assert_eq` is otherwise hard to read.
        println!("{}", expected_result);
        println!("{}", actual_result);
Collaborator: nit: It would be preferable to only do this on failure. I think it's fine if we do

if actual_result.to_string() != expected_result {
    println!("{}", expected_result);
    println!("{}", actual_result);
    panic!(...)
}
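
For reference, a filled-in version of this suggestion could look like the following; the panic message is illustrative, since the review comment left the panic! arguments elided:

if actual_result.to_string() != expected_result {
    println!("{}", expected_result);
    println!("{}", actual_result);
    // Hypothetical message, not part of the original review comment.
    panic!("distributed query result did not match the expected output");
}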


        // Compare against result. The regression this is testing for would have NULL values in
        // the second and third columns.
        assert_eq!(actual_result.to_string(), expected_result);

        Ok(())
    }
}
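
Since the module is gated behind #[cfg(all(feature = "integration", test))], these tests compile and run only when the integration feature is enabled, for example with a standard cargo invocation:

cargo test --features integration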