12 changes: 10 additions & 2 deletions src/execution_plans/network_coalesce.rs
@@ -14,11 +14,12 @@ use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{TryFutureExt, TryStreamExt};
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use prost::Message;
use std::any::Any;
@@ -283,6 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
};

let metrics_collection_capture = self_ready.metrics_collection.clone();
let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
let (mapper, _indices) = adapter.map_schema(&self.schema())?;
Collaborator:

This looks like 1:1 schema mapping. What does it do? Is this just a way to assert that the schema hasn't changed? I think adding a test which shows why this is necessary would be good.

Contributor Author:

The schema does change. The Arrow Flight data hydrates dictionary values into real values, and so the schema of the incoming RecordBatch is different. We use the mapper here to map back to what the execution plan expects.
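
To make this concrete, here is a minimal sketch of the schema drift (the company column is borrowed from the error message further down; the exact fields are illustrative):

```rust
use datafusion::arrow::datatypes::{DataType, Field, Schema};

// Schema the execution plan expects: a dictionary-encoded string column.
let expected = Schema::new(vec![Field::new(
    "company",
    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
    true,
)]);

// Schema of the batches decoded from the Flight stream: the encoder has
// hydrated the dictionary, so the column arrives as plain Utf8.
let incoming = Schema::new(vec![Field::new("company", DataType::Utf8, true)]);

// The mapper produced by DefaultSchemaAdapterFactory casts each incoming
// batch back to the expected schema.
assert_ne!(expected, incoming);
```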

Collaborator:

I noticed that tests still pass without this line.

IIUC, the root problem was on the server - we were sending an empty FlightData message to the client without sending the schema / dictionary message first. You've fixed this problem.

I don't see an issue on the client that this solves. The flight decoder in the client should be able to handle any message sent by the encoder on the server.

The metrics collector on the client passes through flight data unchanged, minus clearing the app_metadata.

Collaborator:

I would prefer to either have a test which shows why this is needed or remove the lines. Lmk if you think otherwise though!

Once again, I appreciate the contribution 🙏🏽 - the old empty flight data code was sketchy for sure.

Contributor Author (@cetra3, Oct 8, 2025):

test_metrics_collection_e2e_4 fails with this removed from both network plans.

Collaborator:

Ah sorry, I commented on one but not the other. This LGTM.

Contributor Author:

I've added an assert here to make sure the schema matches: a141a3b
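
A plausible shape for that check (hypothetical sketch; the actual change is in a141a3b):

```rust
// Hypothetical sketch; the real assertion landed in commit a141a3b.
assert_eq!(
    batch.schema(),
    self.schema(),
    "mapped batch schema must match the plan schema"
);
```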

let stream = async move {
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
let stream = client
@@ -297,7 +300,12 @@

Ok(
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
.map_err(map_flight_to_datafusion_error),
.map_err(map_flight_to_datafusion_error)
.map(move |batch| {
let batch = batch?;

mapper.map_batch(batch)
}),
)
}
.try_flatten_stream();
12 changes: 11 additions & 1 deletion src/execution_plans/network_shuffle.rs
@@ -14,6 +14,7 @@ use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::Partitioning;
@@ -308,8 +309,12 @@ impl ExecutionPlan for NetworkShuffleExec {
let task_context = DistributedTaskContext::from_ctx(&context);
let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;

let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
let (mapper, _indices) = adapter.map_schema(&self.schema())?;

let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
let channel_resolver = Arc::clone(&channel_resolver);
let mapper = mapper.clone();

let ticket = Request::from_parts(
MetadataMap::from_headers(context_headers.clone()),
@@ -349,7 +354,12 @@

Ok(
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
.map_err(map_flight_to_datafusion_error),
.map_err(map_flight_to_datafusion_error)
.map(move |batch| {
let batch = batch?;

mapper.map_batch(batch)
}),
)
}
.try_flatten_stream()
128 changes: 67 additions & 61 deletions src/flight_service/do_get.rs
@@ -3,7 +3,6 @@ use crate::config_extension_ext::ContextGrpcMetadata;
use crate::execution_plans::{DistributedTaskContext, StageExec};
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
use crate::metrics::TaskMetricsCollector;
use crate::metrics::proto::df_metrics_set_to_proto;
use crate::protobuf::{
@@ -16,18 +15,17 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use bytes::Bytes;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};

use datafusion::common::exec_datafusion_err;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use futures::{Stream, stream};
use futures::stream;
use futures::{StreamExt, TryStreamExt};
use prost::Message;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Poll;
use tonic::{Request, Response, Status};

#[derive(Clone, PartialEq, ::prost::Message)]
@@ -173,51 +171,82 @@ fn flight_stream_from_record_batch_stream(
evict_stage: impl FnOnce() + Send + 'static,
stream: SendableRecordBatchStream,
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
let flight_data_stream =
let mut flight_data_stream =
FlightDataEncoderBuilder::new()
.with_schema(stream.schema().clone())
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));

let trailing_metrics_stream = TrailingFlightDataStream::new(
move || {
if stage_data
.num_partitions_remaining
.fetch_sub(1, Ordering::SeqCst)
== 1
{
evict_stage();

let metrics_stream =
collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
|err| {
Status::internal(format!(
"error collecting metrics in arrow flight endpoint: {err}"
))
},
)?;

return Ok(Some(metrics_stream));
}

Ok(None)
},
flight_data_stream,
);
// executed once when the stream ends
// decorates the last flight data with our metrics
let mut final_closure = Some(move |last_flight_data| {
if stage_data
.num_partitions_remaining
.fetch_sub(1, Ordering::SeqCst)
== 1
{
evict_stage();

collect_and_create_metrics_flight_data(stage_key, stage_data.stage, last_flight_data)
} else {
Ok(last_flight_data)
}
});

// this is used to peek the new value
// so that we can add our metrics to the last flight data
let mut current_value = None;

let stream =
stream::poll_fn(
Collaborator:

I think this logic could be factored out into a map_last_stream method that is not specific to Arrow Flight or anything. Just a function that takes a stream and a FnOnce closure and returns a stream whose last element has the provided FnOnce applied to it.

If we had this generic stream mapping, that would allow us to remove the callback_stream.rs file that does something similar, but with a callback that does not mutate the last element.

That might be too much of a refactor for this single PR though... it's probably better done in a new one.
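
A minimal sketch of such a combinator, assuming only the futures crate (the name map_last_stream and this exact shape are hypothetical, not an existing API):

```rust
use futures::{Stream, StreamExt, stream};

// Hypothetical generic combinator: yields every element of `input` unchanged
// except the last one, which has `f` applied to it before being emitted.
fn map_last_stream<S, T, F>(input: S, f: F) -> impl Stream<Item = T>
where
    S: Stream<Item = T> + Unpin,
    F: FnOnce(T) -> T,
{
    stream::unfold(
        (input.fuse(), None::<T>, Some(f)),
        |(mut input, mut buffered, mut f)| async move {
            loop {
                match input.next().await {
                    // Hold each element back until we know another follows it.
                    Some(next) => match buffered.replace(next) {
                        Some(prev) => return Some((prev, (input, buffered, f))),
                        None => continue,
                    },
                    // Source exhausted: whatever is buffered is the last element.
                    None => {
                        let last = buffered.take()?;
                        let f = f.take()?;
                        return Some((f(last), (input, buffered, None)));
                    }
                }
            }
        },
    )
}
```

For the do_get case the stream items are Results, so the PR's final_closure would need a small adapter along the lines of |item: Result<_, _>| item.and_then(closure) before being passed in.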

Collaborator:

Actually, here's a suggestion about how we can use what this PR brings to remove the old callback_stream.rs:

pydantic#1

Let me know if that makes sense.

move |cx| match futures::ready!(flight_data_stream.poll_next_unpin(cx)) {
Some(Ok(new_val)) => {
match current_value.take() {
// This is the first value, so we store it and repoll to get the next value
None => {
current_value = Some(new_val);
cx.waker().wake_by_ref();
Poll::Pending
}

Some(existing) => {
current_value = Some(new_val);

Poll::Ready(Some(Ok(existing)))
}
}
}
// this is our last value, so we add our metrics to this flight data
None => match current_value.take() {
Some(existing) => {
// make sure we wake ourselves to finish the stream
cx.waker().wake_by_ref();

if let Some(closure) = final_closure.take() {
Poll::Ready(Some(closure(existing)))
} else {
unreachable!("the closure is only executed once")
}
}
None => Poll::Ready(None),
},
err => Poll::Ready(err),
},
);

Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
Response::new(Box::pin(stream.map_err(|err| match err {
FlightError::Tonic(status) => *status,
_ => Status::internal(format!("Error during flight stream: {err}")),
})))
}

// Collects metrics from the provided stage and encodes it into a stream of flight data using
// the schema of the stage.
/// Collects metrics from the provided stage and includes it in the flight data
fn collect_and_create_metrics_flight_data(
stage_key: StageKey,
stage: Arc<StageExec>,
) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> {
incoming: FlightData,
) -> Result<FlightData, FlightError> {
// Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
let mut result = TaskMetricsCollector::new()
.collect(stage.plan.clone())
@@ -252,35 +281,12 @@
})),
};

let metrics_flight_data =
empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
Ok(Box::pin(stream::once(
async move { Ok(metrics_flight_data) },
)))
}

/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
/// since they skip messages with empty RecordBatch data.
pub fn empty_flight_data_with_app_metadata(
Collaborator:

This is overall looking really good! I have one question purely for my understanding:

Before, we were appending one extra message to the Arrow Flight stream, containing the app_metadata with the metrics and no actual body (an empty RecordBatch).

Now, we are intercepting the last message in the Arrow Flight stream and enriching it with the app_metadata.

I see that without these changes, your new test fails with the following error:

Protocol error: Failed to create empty batch FlightData: Ipc error: no dict id for field company

But I don't fully understand why, as I'd have expected the "Before" and "Now" approaches to be equivalent.

Contributor Author:

So, based on my understanding, when encoding Arrow Flight data there are three main message types:

  • Schemas
  • Record Batches
  • Dictionaries

When the stream starts, the schema is encoded and sent. At that point (when we encode and send the schema) the dictionary tracker is populated based on which dictionaries are present in the fields. Then we get to the record batches. If any of them have fields with dictionary types, we use the dictionary tracker to determine whether we have already sent the dictionary, and send it if not, along with the record batch itself.

In the old code, we weren't encoding the schema with the empty record batch, so the dictionary tracker was never populated with the dictionaries. This meant that when it came time to encode the record batch (and possibly a dictionary), the encoder checked the schema of the record batch, determined that a field has a dictionary, and consulted the tracker as to what to do.

But because we hadn't encoded the schema, the dictionary tracker was in a bad state, and so you get the no dict id for field error. I tried side-stepping this problem by encoding the schema first, to sort out the tracker, and throwing away the result. However, to make matters worse, the Arrow Flight encoder removes all dictionaries by default, hydrating them into their underlying values. This meant we essentially had two different schemas for the one stream.

So you can fix it by asking the flight data encoder for the schema it has and using that to encode the empty batch, but then if you ever wanted to change whether to hydrate dictionaries in the future, it'd be broken. So I figured a simpler fix is what's proposed: don't bother writing out an empty record batch, just append the data to the last value.
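
For reference, the hydration behaviour described here corresponds to the encoder's dictionary-handling option (assuming the current arrow-flight API; shown only to illustrate why pinning the empty batch to a single schema was fragile):

```rust
use arrow_flight::encode::{DictionaryHandling, FlightDataEncoderBuilder};

// The encoder hydrates dictionaries into their underlying values by default.
// Opting into resending dictionaries instead would change the wire schema
// again, which is why encoding an empty batch against one fixed schema broke.
let _builder = FlightDataEncoderBuilder::new()
    .with_dictionary_handling(DictionaryHandling::Resend);
```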

Collaborator:

👋🏽 Thanks for the contribution! Sending the metadata with the last message is more optimal for sure. I'm reviewing this PR now.

metadata: FlightAppMetadata,
schema: SchemaRef,
) -> Result<FlightData, FlightError> {
let mut buf = vec![];
metadata
flight_app_metadata
.encode(&mut buf)
.map_err(|err| FlightError::ProtocolError(err.to_string()))?;

let empty_batch = RecordBatch::new_empty(schema);
let options = IpcWriteOptions::default();
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(true);
let (_, encoded_data) = data_gen
.encoded_batch(&empty_batch, &mut dictionary_tracker, &options)
.map_err(|e| {
FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}"))
})?;
Ok(FlightData::from(encoded_data).with_app_metadata(buf))
Ok(incoming.with_app_metadata(buf))
}

#[cfg(test)]
1 change: 0 additions & 1 deletion src/flight_service/mod.rs
@@ -1,7 +1,6 @@
mod do_get;
mod service;
mod session_builder;
pub(super) mod trailing_flight_data_stream;
pub(crate) use do_get::DoGet;

pub use service::ArrowFlightEndpoint;