#[cfg(all(feature = "integration", test))]
mod tests {
+    use datafusion::arrow::array::{Int32Array, StringArray};
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::arrow::util::pretty::pretty_format_batches;
    use datafusion::physical_plan::{displayable, execute_stream};
    use datafusion_distributed::test_utils::localhost::start_localhost_context;
    use datafusion_distributed::test_utils::parquet::register_parquet_tables;
+    use datafusion_distributed::test_utils::session_context::register_temp_parquet_table;
    use datafusion_distributed::{
        DefaultSessionBuilder, DistributedConfig, apply_network_boundaries, assert_snapshot,
        display_plan_ascii, distribute_plan,
    };
    use futures::TryStreamExt;
    use std::error::Error;
+    use std::sync::Arc;
+    use uuid::Uuid;

    #[tokio::test]
    async fn distributed_aggregation() -> Result<(), Box<dyn Error>> {
@@ -149,4 +155,98 @@ mod tests {

        Ok(())
    }
+
+    /// Test that multiple first_value() aggregations work correctly in distributed queries.
+    // TODO: Once https://github.com/apache/datafusion/pull/18303 is merged, this test will
+    // lose its meaning, since that PR will mask the underlying problem; different queries or
+    // a new approach will be needed in that case.
+    #[tokio::test]
+    async fn test_multiple_first_value_aggregations() -> Result<(), Box<dyn Error>> {
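+        // Spin up an in-process cluster of 3 localhost workers for the
+        // distributed stages to run on.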
+        let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_id", DataType::Int32, false),
+            Field::new("trace_id", DataType::Utf8, false),
+            Field::new("value", DataType::Int32, false),
+        ]));
+
+        // Create 2 batches that will be stored as separate parquet files
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2])),
+                Arc::new(StringArray::from(vec!["trace1", "trace2"])),
+                Arc::new(Int32Array::from(vec![100, 200])),
+            ],
+        )?;
+
+        let batch2 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![3, 4])),
+                Arc::new(StringArray::from(vec!["trace3", "trace4"])),
+                Arc::new(Int32Array::from(vec![300, 400])),
+            ],
+        )?;
+
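+        // Each helper call writes its batches to a temporary Parquet file and
+        // returns the file's path, which is used to assemble the table below.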
+        let file1 =
+            register_temp_parquet_table("records_part1", schema.clone(), vec![batch1], &ctx)
+                .await?;
+        let file2 =
+            register_temp_parquet_table("records_part2", schema.clone(), vec![batch2], &ctx)
+                .await?;
+
+        // Create a partitioned table by copying both files into a single directory
+        let temp_dir = std::env::temp_dir();
+        let table_dir = temp_dir.join(format!("partitioned_table_{}", Uuid::new_v4()));
+        std::fs::create_dir(&table_dir)?;
+        std::fs::copy(&file1, table_dir.join("part1.parquet"))?;
+        std::fs::copy(&file2, table_dir.join("part2.parquet"))?;
+
+        // Register the directory as a partitioned table
+        ctx.register_parquet(
+            "records_partitioned",
+            table_dir.to_str().unwrap(),
+            datafusion::prelude::ParquetReadOptions::default(),
+        )
+        .await?;
+
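+        // Every group_id appears in exactly one row, so both first_value() calls
+        // are deterministic even without an ORDER BY inside the aggregate.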
+        let query = r#"SELECT group_id, first_value(trace_id) AS fv1, first_value(value) AS fv2
+            FROM records_partitioned
+            GROUP BY group_id
+            ORDER BY group_id"#;
+
+        let df = ctx.sql(query).await?;
+        let physical = df.create_physical_plan().await?;
+
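+        // Distribute the plan: insert network boundaries (two tasks per network
+        // shuffle), then split it into stages for the localhost workers.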
+        let cfg = DistributedConfig::default().with_network_shuffle_tasks(2);
+        let physical_distributed = apply_network_boundaries(physical, &cfg)?;
+        let physical_distributed = distribute_plan(physical_distributed)?;
+
+        // Execute distributed query
+        let batches_distributed = execute_stream(physical_distributed, ctx.task_ctx())?
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        let actual_result = pretty_format_batches(&batches_distributed)?;
+        let expected_result = "\
++----------+--------+-----+
+| group_id | fv1    | fv2 |
++----------+--------+-----+
+| 1        | trace1 | 100 |
+| 2        | trace2 | 200 |
+| 3        | trace3 | 300 |
+| 4        | trace4 | 400 |
++----------+--------+-----+";
+
+        // Print both, since the error message from `assert_eq!` is otherwise hard to read.
+        println!("{}", expected_result);
+        println!("{}", actual_result);
+
+        // Compare against the expected result. The regression this test guards
+        // against produced NULL values in the second and third columns.
+        assert_eq!(actual_result.to_string(), expected_result);
+
+        Ok(())
+    }
}