
Commit 76cbed2

Merge branch 'main' into gabrielmusat/rework-stage-hierarchy-for-better-interoperability
# Conflicts:
#   src/common/mod.rs
#   src/execution_plans/network_coalesce.rs
#   src/execution_plans/network_shuffle.rs
#   src/flight_service/do_get.rs
2 parents e469202 + 5db1101

9 files changed, +154 -433 lines changed

src/common/callback_stream.rs

Lines changed: 0 additions & 84 deletions
This file was deleted.

src/common/map_last_stream.rs

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
use futures::{Stream, StreamExt, stream};
use std::task::Poll;

/// Maps the last element of the provided stream.
pub(crate) fn map_last_stream<T>(
    mut input: impl Stream<Item = T> + Unpin,
    map_f: impl FnOnce(T) -> T,
) -> impl Stream<Item = T> + Unpin {
    let mut final_closure = Some(map_f);

    // this is used to peek the new value so that we can map upon emitting the last message
    let mut current_value = None;

    stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) {
        Some(new_val) => {
            match current_value.take() {
                // This is the first value, so we store it and repoll to get the next value
                None => {
                    current_value = Some(new_val);
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }

                Some(existing) => {
                    current_value = Some(new_val);

                    Poll::Ready(Some(existing))
                }
            }
        }
        // this is our last value, so we map it using the user provided closure
        None => match current_value.take() {
            Some(existing) => {
                // make sure we wake ourselves to finish the stream
                cx.waker().wake_by_ref();

                if let Some(closure) = final_closure.take() {
                    Poll::Ready(Some(closure(existing)))
                } else {
                    unreachable!("the closure is only executed once")
                }
            }
            None => Poll::Ready(None),
        },
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    #[tokio::test]
    async fn test_map_last_stream_empty_stream() {
        let input = stream::empty::<i32>();
        let mapped = map_last_stream(input, |x| x + 10);
        let result: Vec<i32> = mapped.collect().await;
        assert_eq!(result, Vec::<i32>::new());
    }

    #[tokio::test]
    async fn test_map_last_stream_single_element() {
        let input = stream::iter(vec![5]);
        let mapped = map_last_stream(input, |x| x * 2);
        let result: Vec<i32> = mapped.collect().await;
        assert_eq!(result, vec![10]);
    }

    #[tokio::test]
    async fn test_map_last_stream_multiple_elements() {
        let input = stream::iter(vec![1, 2, 3, 4]);
        let mapped = map_last_stream(input, |x| x + 100);
        let result: Vec<i32> = mapped.collect().await;
        assert_eq!(result, vec![1, 2, 3, 104]); // Only the last element is transformed
    }

    #[tokio::test]
    async fn test_map_last_stream_preserves_order() {
        let input = stream::iter(vec![10, 20, 30, 40, 50]);
        let mapped = map_last_stream(input, |x| x - 50);
        let result: Vec<i32> = mapped.collect().await;
        assert_eq!(result, vec![10, 20, 30, 40, 0]); // Last element: 50 - 50 = 0
    }
}
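Since the helper is generic over the stream's item type, inside the crate it applies just as well to fallible streams (items of Result, as a Flight record-batch stream would produce) as to the plain integers in the tests above. A minimal sketch under that assumption; the Result<i32, String> item type and the test name are invented for illustration and are not part of the diff:

use futures::{StreamExt, stream};

use crate::common::map_last_stream;

#[tokio::test]
async fn test_map_last_stream_with_fallible_items() {
    // Items are Results, as they would be for a record-batch stream coming off the wire.
    let input = stream::iter(vec![Ok::<i32, String>(1), Ok(2), Ok(3)]);
    // Only the final item reaches the closure; earlier items are forwarded untouched.
    let mapped = map_last_stream(input, |last| last.map(|v| v * 100));
    let out: Vec<_> = mapped.collect().await;
    assert_eq!(out, vec![Ok(1), Ok(2), Ok(300)]);
}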

src/common/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
-mod callback_stream;
 mod execution_plan_ops;
+mod map_last_stream;
 mod partitioning;
 #[allow(unused)]
 pub mod ttl_map;

-pub(crate) use callback_stream::with_callback;
 pub(crate) use execution_plan_ops::*;
+pub(crate) use map_last_stream::map_last_stream;
 pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};

src/execution_plans/network_coalesce.rs

Lines changed: 10 additions & 2 deletions
@@ -14,11 +14,12 @@ use arrow_flight::error::FlightError;
 use bytes::Bytes;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-use futures::{TryFutureExt, TryStreamExt};
+use futures::{StreamExt, TryFutureExt, TryStreamExt};
 use http::Extensions;
 use prost::Message;
 use std::any::Any;
@@ -298,6 +299,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
         };

         let metrics_collection_capture = self_ready.metrics_collection.clone();
+        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
         let stream = async move {
             let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
             let stream = client
@@ -312,7 +315,12 @@ impl ExecutionPlan for NetworkCoalesceExec {

             Ok(
                 FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                    .map_err(map_flight_to_datafusion_error),
+                    .map_err(map_flight_to_datafusion_error)
+                    .map(move |batch| {
+                        let batch = batch?;
+
+                        mapper.map_batch(batch)
+                    }),
             )
         }
         .try_flatten_stream();
src/execution_plans/network_shuffle.rs

Lines changed: 11 additions & 1 deletion
@@ -15,6 +15,7 @@ use arrow_flight::error::FlightError;
 use bytes::Bytes;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_datafusion_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::Partitioning;
@@ -328,8 +329,12 @@ impl ExecutionPlan for NetworkShuffleExec {
         let task_context = DistributedTaskContext::from_ctx(&context);
         let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;

+        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
+
         let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
             let channel_resolver = Arc::clone(&channel_resolver);
+            let mapper = mapper.clone();

             let ticket = Request::from_parts(
                 MetadataMap::from_headers(context_headers.clone()),
@@ -370,7 +375,12 @@ impl ExecutionPlan for NetworkShuffleExec {

             Ok(
                 FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                    .map_err(map_flight_to_datafusion_error),
+                    .map_err(map_flight_to_datafusion_error)
+                    .map(move |batch| {
+                        let batch = batch?;
+
+                        mapper.map_batch(batch)
+                    }),
             )
         }
         .try_flatten_stream()
