Skip to content

Commit 5db1101

Browse files
cetra3gabotechs
andauthored
Fix Dictionary Encoded Values (#174)
* Adjust the dictionary tracker for empty dictionary * Fix dictionaries in streams * Use map_last_stream.rs in favor of callback_stream.rs * Add `assert` for the batch schema --------- Co-authored-by: Gabriel Musat Mestre <[email protected]>
1 parent ce600dd commit 5db1101

File tree

9 files changed

+153
-432
lines changed

9 files changed

+153
-432
lines changed

src/common/callback_stream.rs

Lines changed: 0 additions & 84 deletions
This file was deleted.

src/common/map_last_stream.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use futures::{Stream, StreamExt, stream};
2+
use std::task::Poll;
3+
4+
/// Maps the last element of the provided stream.
5+
pub(crate) fn map_last_stream<T>(
6+
mut input: impl Stream<Item = T> + Unpin,
7+
map_f: impl FnOnce(T) -> T,
8+
) -> impl Stream<Item = T> + Unpin {
9+
let mut final_closure = Some(map_f);
10+
11+
// this is used to peek the new value so that we can map upon emitting the last message
12+
let mut current_value = None;
13+
14+
stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) {
15+
Some(new_val) => {
16+
match current_value.take() {
17+
// This is the first value, so we store it and repoll to get the next value
18+
None => {
19+
current_value = Some(new_val);
20+
cx.waker().wake_by_ref();
21+
Poll::Pending
22+
}
23+
24+
Some(existing) => {
25+
current_value = Some(new_val);
26+
27+
Poll::Ready(Some(existing))
28+
}
29+
}
30+
}
31+
// this is our last value, so we map it using the user provided closure
32+
None => match current_value.take() {
33+
Some(existing) => {
34+
// make sure we wake ourselves to finish the stream
35+
cx.waker().wake_by_ref();
36+
37+
if let Some(closure) = final_closure.take() {
38+
Poll::Ready(Some(closure(existing)))
39+
} else {
40+
unreachable!("the closure is only executed once")
41+
}
42+
}
43+
None => Poll::Ready(None),
44+
},
45+
})
46+
}
47+
48+
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Runs `source` through `map_last_stream` with `f` and gathers the output.
    async fn collect_mapped(
        source: impl Stream<Item = i32> + Unpin,
        f: impl FnOnce(i32) -> i32,
    ) -> Vec<i32> {
        map_last_stream(source, f).collect().await
    }

    /// An empty input yields an empty output.
    #[tokio::test]
    async fn test_map_last_stream_empty_stream() {
        let output = collect_mapped(stream::empty::<i32>(), |x| x + 10).await;
        assert_eq!(output, Vec::<i32>::new());
    }

    /// With a single element, that element is the last one and gets mapped.
    #[tokio::test]
    async fn test_map_last_stream_single_element() {
        let output = collect_mapped(stream::iter(vec![5]), |x| x * 2).await;
        assert_eq!(output, vec![10]);
    }

    /// Only the final element is transformed; the others pass through as-is.
    #[tokio::test]
    async fn test_map_last_stream_multiple_elements() {
        let output = collect_mapped(stream::iter(vec![1, 2, 3, 4]), |x| x + 100).await;
        assert_eq!(output, vec![1, 2, 3, 104]);
    }

    /// Elements come out in input order, with the last one mapped (50 - 50 = 0).
    #[tokio::test]
    async fn test_map_last_stream_preserves_order() {
        let output = collect_mapped(stream::iter(vec![10, 20, 30, 40, 50]), |x| x - 50).await;
        assert_eq!(output, vec![10, 20, 30, 40, 0]);
    }
}

src/common/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
mod callback_stream;
1+
mod map_last_stream;
22
mod partitioning;
33
#[allow(unused)]
44
pub mod ttl_map;
55

6-
pub(crate) use callback_stream::with_callback;
6+
pub(crate) use map_last_stream::map_last_stream;
77
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};

src/execution_plans/network_coalesce.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
1515
use dashmap::DashMap;
1616
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
17+
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
1718
use datafusion::error::DataFusionError;
1819
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1920
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2021
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21-
use futures::{TryFutureExt, TryStreamExt};
22+
use futures::{StreamExt, TryFutureExt, TryStreamExt};
2223
use http::Extensions;
2324
use prost::Message;
2425
use std::any::Any;
@@ -283,6 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
283284
};
284285

285286
let metrics_collection_capture = self_ready.metrics_collection.clone();
287+
let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
288+
let (mapper, _indices) = adapter.map_schema(&self.schema())?;
286289
let stream = async move {
287290
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
288291
let stream = client
@@ -297,7 +300,12 @@ impl ExecutionPlan for NetworkCoalesceExec {
297300

298301
Ok(
299302
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
300-
.map_err(map_flight_to_datafusion_error),
303+
.map_err(map_flight_to_datafusion_error)
304+
.map(move |batch| {
305+
let batch = batch?;
306+
307+
mapper.map_batch(batch)
308+
}),
301309
)
302310
}
303311
.try_flatten_stream();

src/execution_plans/network_shuffle.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
1515
use dashmap::DashMap;
1616
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
17+
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
1718
use datafusion::error::DataFusionError;
1819
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1920
use datafusion::physical_expr::Partitioning;
@@ -308,8 +309,12 @@ impl ExecutionPlan for NetworkShuffleExec {
308309
let task_context = DistributedTaskContext::from_ctx(&context);
309310
let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;
310311

312+
let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
313+
let (mapper, _indices) = adapter.map_schema(&self.schema())?;
314+
311315
let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
312316
let channel_resolver = Arc::clone(&channel_resolver);
317+
let mapper = mapper.clone();
313318

314319
let ticket = Request::from_parts(
315320
MetadataMap::from_headers(context_headers.clone()),
@@ -349,7 +354,12 @@ impl ExecutionPlan for NetworkShuffleExec {
349354

350355
Ok(
351356
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
352-
.map_err(map_flight_to_datafusion_error),
357+
.map_err(map_flight_to_datafusion_error)
358+
.map(move |batch| {
359+
let batch = batch?;
360+
361+
mapper.map_batch(batch)
362+
}),
353363
)
354364
}
355365
.try_flatten_stream()

0 commit comments

Comments
 (0)