Skip to content

Commit 624228b

Browse files
committed
Fix dictionaries in streams
1 parent 576db69 commit 624228b

File tree

6 files changed

+105
-302
lines changed

6 files changed

+105
-302
lines changed

src/execution_plans/network_coalesce.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
1515
use dashmap::DashMap;
1616
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
17+
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
1718
use datafusion::error::DataFusionError;
1819
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1920
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2021
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21-
use futures::{TryFutureExt, TryStreamExt};
22+
use futures::{StreamExt, TryFutureExt, TryStreamExt};
2223
use http::Extensions;
2324
use prost::Message;
2425
use std::any::Any;
@@ -283,6 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
283284
};
284285

285286
let metrics_collection_capture = self_ready.metrics_collection.clone();
287+
let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
288+
let (mapper, _indices) = adapter.map_schema(&self.schema())?;
286289
let stream = async move {
287290
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
288291
let stream = client
@@ -297,7 +300,12 @@ impl ExecutionPlan for NetworkCoalesceExec {
297300

298301
Ok(
299302
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
300-
.map_err(map_flight_to_datafusion_error),
303+
.map_err(map_flight_to_datafusion_error)
304+
.map(move |batch| {
305+
let batch = batch?;
306+
307+
mapper.map_batch(batch)
308+
}),
301309
)
302310
}
303311
.try_flatten_stream();

src/execution_plans/network_shuffle.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
1515
use dashmap::DashMap;
1616
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
17+
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
1718
use datafusion::error::DataFusionError;
1819
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1920
use datafusion::physical_expr::Partitioning;
@@ -308,8 +309,12 @@ impl ExecutionPlan for NetworkShuffleExec {
308309
let task_context = DistributedTaskContext::from_ctx(&context);
309310
let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;
310311

312+
let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
313+
let (mapper, _indices) = adapter.map_schema(&self.schema())?;
314+
311315
let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
312316
let channel_resolver = Arc::clone(&channel_resolver);
317+
let mapper = mapper.clone();
313318

314319
let ticket = Request::from_parts(
315320
MetadataMap::from_headers(context_headers.clone()),
@@ -349,7 +354,12 @@ impl ExecutionPlan for NetworkShuffleExec {
349354

350355
Ok(
351356
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
352-
.map_err(map_flight_to_datafusion_error),
357+
.map_err(map_flight_to_datafusion_error)
358+
.map(move |batch| {
359+
let batch = batch?;
360+
361+
mapper.map_batch(batch)
362+
}),
353363
)
354364
}
355365
.try_flatten_stream()

src/flight_service/do_get.rs

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use crate::config_extension_ext::ContextGrpcMetadata;
33
use crate::execution_plans::{DistributedTaskContext, StageExec};
44
use crate::flight_service::service::ArrowFlightEndpoint;
55
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
6-
use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
76
use crate::metrics::TaskMetricsCollector;
87
use crate::metrics::proto::df_metrics_set_to_proto;
98
use crate::protobuf::{
@@ -16,18 +15,17 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
1615
use arrow_flight::error::FlightError;
1716
use arrow_flight::flight_service_server::FlightService;
1817
use bytes::Bytes;
19-
use datafusion::arrow::array::RecordBatch;
20-
use datafusion::arrow::datatypes::SchemaRef;
21-
use datafusion::arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
18+
2219
use datafusion::common::exec_datafusion_err;
2320
use datafusion::execution::SendableRecordBatchStream;
2421
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2522
use datafusion::prelude::SessionContext;
26-
use futures::TryStreamExt;
27-
use futures::{Stream, stream};
23+
use futures::stream;
24+
use futures::{StreamExt, TryStreamExt};
2825
use prost::Message;
2926
use std::sync::Arc;
3027
use std::sync::atomic::{AtomicUsize, Ordering};
28+
use std::task::Poll;
3129
use tonic::{Request, Response, Status};
3230

3331
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -173,51 +171,82 @@ fn flight_stream_from_record_batch_stream(
173171
evict_stage: impl FnOnce() + Send + 'static,
174172
stream: SendableRecordBatchStream,
175173
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
176-
let flight_data_stream =
174+
let mut flight_data_stream =
177175
FlightDataEncoderBuilder::new()
178176
.with_schema(stream.schema().clone())
179177
.build(stream.map_err(|err| {
180178
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
181179
}));
182180

183-
let trailing_metrics_stream = TrailingFlightDataStream::new(
184-
move || {
185-
if stage_data
186-
.num_partitions_remaining
187-
.fetch_sub(1, Ordering::SeqCst)
188-
== 1
189-
{
190-
evict_stage();
191-
192-
let metrics_stream =
193-
collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
194-
|err| {
195-
Status::internal(format!(
196-
"error collecting metrics in arrow flight endpoint: {err}"
197-
))
198-
},
199-
)?;
200-
201-
return Ok(Some(metrics_stream));
202-
}
203-
204-
Ok(None)
205-
},
206-
flight_data_stream,
207-
);
181+
// executed once when the stream ends
182+
// decorates the last flight data with our metrics
183+
let mut final_closure = Some(move |last_flight_data| {
184+
if stage_data
185+
.num_partitions_remaining
186+
.fetch_sub(1, Ordering::SeqCst)
187+
== 1
188+
{
189+
evict_stage();
190+
191+
collect_and_create_metrics_flight_data(stage_key, stage_data.stage, last_flight_data)
192+
} else {
193+
Ok(last_flight_data)
194+
}
195+
});
196+
197+
// this is used to peek the new value
198+
// so that we can add our metrics to the last flight data
199+
let mut current_value = None;
200+
201+
let stream =
202+
stream::poll_fn(
203+
move |cx| match futures::ready!(flight_data_stream.poll_next_unpin(cx)) {
204+
Some(Ok(new_val)) => {
205+
match current_value.take() {
206+
// This is the first value, so we store it and repoll to get the next value
207+
None => {
208+
current_value = Some(new_val);
209+
cx.waker().wake_by_ref();
210+
Poll::Pending
211+
}
212+
213+
Some(existing) => {
214+
current_value = Some(new_val);
215+
216+
Poll::Ready(Some(Ok(existing)))
217+
}
218+
}
219+
}
220+
// this is our last value, so we add our metrics to this flight data
221+
None => match current_value.take() {
222+
Some(existing) => {
223+
// make sure we wake ourselves to finish the stream
224+
cx.waker().wake_by_ref();
225+
226+
if let Some(closure) = final_closure.take() {
227+
Poll::Ready(Some(closure(existing)))
228+
} else {
229+
unreachable!("the closure is only executed once")
230+
}
231+
}
232+
None => Poll::Ready(None),
233+
},
234+
err => Poll::Ready(err),
235+
},
236+
);
208237

209-
Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
238+
Response::new(Box::pin(stream.map_err(|err| match err {
210239
FlightError::Tonic(status) => *status,
211240
_ => Status::internal(format!("Error during flight stream: {err}")),
212241
})))
213242
}
214243

215-
// Collects metrics from the provided stage and encodes it into a stream of flight data using
216-
// the schema of the stage.
244+
/// Collects metrics from the provided stage and includes it in the flight data
217245
fn collect_and_create_metrics_flight_data(
218246
stage_key: StageKey,
219247
stage: Arc<StageExec>,
220-
) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> {
248+
incoming: FlightData,
249+
) -> Result<FlightData, FlightError> {
221250
// Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
222251
let mut result = TaskMetricsCollector::new()
223252
.collect(stage.plan.clone())
@@ -252,35 +281,12 @@ fn collect_and_create_metrics_flight_data(
252281
})),
253282
};
254283

255-
let metrics_flight_data =
256-
empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
257-
Ok(Box::pin(stream::once(
258-
async move { Ok(metrics_flight_data) },
259-
)))
260-
}
261-
262-
/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
263-
/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
264-
/// since they skip messages with empty RecordBatch data.
265-
pub fn empty_flight_data_with_app_metadata(
266-
metadata: FlightAppMetadata,
267-
schema: SchemaRef,
268-
) -> Result<FlightData, FlightError> {
269284
let mut buf = vec![];
270-
metadata
285+
flight_app_metadata
271286
.encode(&mut buf)
272287
.map_err(|err| FlightError::ProtocolError(err.to_string()))?;
273288

274-
let empty_batch = RecordBatch::new_empty(schema);
275-
let options = IpcWriteOptions::default();
276-
let data_gen = IpcDataGenerator::default();
277-
let mut dictionary_tracker = DictionaryTracker::new(false);
278-
let (_, encoded_data) = data_gen
279-
.encoded_batch(&empty_batch, &mut dictionary_tracker, &options)
280-
.map_err(|e| {
281-
FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}"))
282-
})?;
283-
Ok(FlightData::from(encoded_data).with_app_metadata(buf))
289+
Ok(incoming.with_app_metadata(buf))
284290
}
285291

286292
#[cfg(test)]

src/flight_service/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
mod do_get;
22
mod service;
33
mod session_builder;
4-
pub(super) mod trailing_flight_data_stream;
54
pub(crate) use do_get::DoGet;
65

76
pub use service::ArrowFlightEndpoint;

0 commit comments

Comments
 (0)