Skip to content

Commit 5129632

Browse files
flight_service: emit metrics from ArrowFlightEndpoint
This change updates the ArrowFlightEndpoint to collect metrics and emit them. When the last partition in a task is finished, the ArrowFlightEndpoint collects metrics and emits them via the TrailingFlightDataStream. Previously, we would determine if a partition is finished when the request first hit the endpoint. Now, we do it on stream completition. This is crutial for metrics collection because we need to know that the stream is exhausted, meaning that there's no data flowing in the plan and metrics are not actively being updated. Since the ArrowFlightEndpoint now emits metrics and NetworkBoundary plan nodes collect metrics, all coordinating StageExecs will now have the full collection of metrics for all tasks. This commit adds integration style tests that assert that the coordinator is recieving the full set of metrics. Follow up work - Only collect metrics if a configuration is set in the SessionContext, removing extra overhead - Display metrics in the plan using EXPLAIN (ANALYZE) - consider using sqllogictest or similar to test the output
1 parent 2df3467 commit 5129632

File tree

10 files changed

+439
-205
lines changed

10 files changed

+439
-205
lines changed

src/execution_plans/metrics.rs

Lines changed: 311 additions & 136 deletions
Large diffs are not rendered by default.

src/execution_plans/metrics_collecting_stream.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ where
6767
};
6868
metrics_collection.insert(stage_key, task_metrics.metrics);
6969
}
70-
7170
flight_data.app_metadata.clear();
71+
7272
Ok(())
7373
}
7474
}
@@ -256,7 +256,8 @@ mod tests {
256256
// Create a stream that emits an error - should be propagated through
257257
let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
258258
let error_stream = stream::iter(vec![Err(stream_error)]);
259-
let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);
259+
let mut collecting_stream =
260+
MetricsCollectingStream::new(error_stream, metrics_collection.clone());
260261

261262
let result = collecting_stream.next().await.unwrap();
262263
assert_protocol_error(result, "stream error from inner stream");

src/execution_plans/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod network_shuffle;
55
mod partition_isolator;
66
mod stage;
77

8+
pub use metrics::collect_and_create_metrics_flight_data;
89
pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady};
910
pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec};
1011
pub use partition_isolator::PartitionIsolatorExec;

src/execution_plans/network_coalesce.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver;
33
use crate::common::scale_partitioning_props;
44
use crate::config_extension_ext::ContextGrpcMetadata;
55
use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err};
6+
use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream;
67
use crate::execution_plans::{DistributedTaskContext, StageExec};
78
use crate::flight_service::DoGet;
89
use crate::metrics::proto::MetricsSetProto;
@@ -297,6 +298,7 @@ impl ExecutionPlan for NetworkCoalesceExec {
297298
return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed");
298299
};
299300

301+
let metrics_collection_capture = self_ready.metrics_collection.clone();
300302
let stream = async move {
301303
let channel = channel_resolver.get_channel_for_url(&url).await?;
302304
let stream = FlightServiceClient::new(channel)
@@ -306,8 +308,13 @@ impl ExecutionPlan for NetworkCoalesceExec {
306308
.into_inner()
307309
.map_err(|err| FlightError::Tonic(Box::new(err)));
308310

309-
Ok(FlightRecordBatchStream::new_from_flight_data(stream)
310-
.map_err(map_flight_to_datafusion_error))
311+
let metrics_collecting_stream =
312+
MetricsCollectingStream::new(stream, metrics_collection_capture);
313+
314+
Ok(
315+
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
316+
.map_err(map_flight_to_datafusion_error),
317+
)
311318
}
312319
.try_flatten_stream();
313320

src/execution_plans/network_shuffle.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver;
33
use crate::common::scale_partitioning;
44
use crate::config_extension_ext::ContextGrpcMetadata;
55
use crate::distributed_physical_optimizer_rule::NetworkBoundary;
6+
use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream;
67
use crate::execution_plans::{DistributedTaskContext, StageExec};
78
use crate::flight_service::DoGet;
89
use crate::metrics::proto::MetricsSetProto;
@@ -330,6 +331,7 @@ impl ExecutionPlan for NetworkShuffleExec {
330331
},
331332
);
332333

334+
let metrics_collection_capture = self_ready.metrics_collection.clone();
333335
async move {
334336
let url = task.url.ok_or(internal_datafusion_err!(
335337
"NetworkShuffleExec: task is unassigned, cannot proceed"
@@ -343,8 +345,13 @@ impl ExecutionPlan for NetworkShuffleExec {
343345
.into_inner()
344346
.map_err(|err| FlightError::Tonic(Box::new(err)));
345347

346-
Ok(FlightRecordBatchStream::new_from_flight_data(stream)
347-
.map_err(map_flight_to_datafusion_error))
348+
let metrics_collecting_stream =
349+
MetricsCollectingStream::new(stream, metrics_collection_capture);
350+
351+
Ok(
352+
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
353+
.map_err(map_flight_to_datafusion_error),
354+
)
348355
}
349356
.try_flatten_stream()
350357
.boxed()

src/flight_service/do_get.rs

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use crate::config_extension_ext::ContextGrpcMetadata;
2-
use crate::execution_plans::{DistributedTaskContext, StageExec};
2+
use crate::execution_plans::{
3+
DistributedTaskContext, StageExec, collect_and_create_metrics_flight_data,
4+
};
35
use crate::flight_service::service::ArrowFlightEndpoint;
46
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
7+
use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
58
use crate::protobuf::{
69
DistributedCodec, StageExecProto, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
710
};
@@ -94,12 +97,6 @@ impl ArrowFlightEndpoint {
9497
})
9598
.await?;
9699
let stage = Arc::clone(&stage_data.stage);
97-
let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
98-
99-
// If all the partitions are done, remove the stage from the cache.
100-
if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
101-
self.task_data_entries.remove(key);
102-
}
103100

104101
// Find out which partition group we are executing
105102
let cfg = session_state.config_mut();
@@ -126,15 +123,30 @@ impl ArrowFlightEndpoint {
126123
.execute(doget.target_partition as usize, session_state.task_ctx())
127124
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
128125

129-
Ok(record_batch_stream_to_response(stream))
126+
let task_data_capture = self.task_data_entries.clone();
127+
Ok(flight_stream_from_record_batch_stream(
128+
key.clone(),
129+
stage,
130+
stage_data.clone(),
131+
move || {
132+
task_data_capture.remove(key.clone());
133+
},
134+
stream,
135+
))
130136
}
131137
}
132138

133139
fn missing(field: &'static str) -> impl FnOnce() -> Status {
134140
move || Status::invalid_argument(format!("Missing field '{field}'"))
135141
}
136142

137-
fn record_batch_stream_to_response(
143+
// Creates a tonic response from a stream of record batches. Handles
144+
// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics.
145+
fn flight_stream_from_record_batch_stream(
146+
stage_key: StageKey,
147+
stage: Arc<StageExec>,
148+
stage_data: TaskData,
149+
evict_stage: impl FnOnce() + Send + 'static,
138150
stream: SendableRecordBatchStream,
139151
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
140152
let flight_data_stream =
@@ -144,7 +156,31 @@ fn record_batch_stream_to_response(
144156
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
145157
}));
146158

147-
Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
159+
let trailing_metrics_stream = TrailingFlightDataStream::new(
160+
move || {
161+
if stage_data
162+
.num_partitions_remaining
163+
.fetch_sub(1, Ordering::SeqCst)
164+
== 1
165+
{
166+
evict_stage();
167+
168+
let metrics_stream = collect_and_create_metrics_flight_data(stage_key, stage)
169+
.map_err(|err| {
170+
Status::internal(format!(
171+
"error collecting metrics in arrow flight endpoint: {err}"
172+
))
173+
})?;
174+
175+
return Ok(Some(metrics_stream));
176+
}
177+
178+
Ok(None)
179+
},
180+
flight_data_stream,
181+
);
182+
183+
Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
148184
FlightError::Tonic(status) => *status,
149185
_ => Status::internal(format!("Error during flight stream: {err}")),
150186
})))
@@ -215,24 +251,27 @@ mod tests {
215251
let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap();
216252
let stage_proto_for_closure = stage_proto.clone();
217253
let endpoint_ref = &endpoint;
254+
218255
let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
219256
let stage_proto = stage_proto_for_closure.clone();
220-
// Create DoGet message
221257
let doget = DoGet {
222258
stage_proto: Some(stage_proto),
223259
target_task_index: task_number,
224260
target_partition: partition,
225261
stage_key: Some(stage_key),
226262
};
227263

228-
// Create Flight ticket
229264
let ticket = Ticket {
230265
ticket: Bytes::from(doget.encode_to_vec()),
231266
};
232267

233-
// Call the actual get() method
234268
let request = Request::new(ticket);
235-
endpoint_ref.get(request).await
269+
let response = endpoint_ref.get(request).await?;
270+
let mut stream = response.into_inner();
271+
272+
// Consume the stream.
273+
while let Some(_flight_data) = stream.try_next().await? {}
274+
Ok::<(), Status>(())
236275
};
237276

238277
// For each task, call do_get() for each partition except the last.
@@ -248,22 +287,22 @@ mod tests {
248287

249288
// Run the last partition of task 0. Any partition number works. Verify that the task state
250289
// is evicted because all partitions have been processed.
251-
let result = do_get(1, 0, task_keys[0].clone()).await;
290+
let result = do_get(2, 0, task_keys[0].clone()).await;
252291
assert!(result.is_ok());
253292
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
254293
assert_eq!(stored_stage_keys.len(), 2);
255294
assert!(stored_stage_keys.contains(&task_keys[1]));
256295
assert!(stored_stage_keys.contains(&task_keys[2]));
257296

258297
// Run the last partition of task 1.
259-
let result = do_get(1, 1, task_keys[1].clone()).await;
298+
let result = do_get(2, 1, task_keys[1].clone()).await;
260299
assert!(result.is_ok());
261300
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
262301
assert_eq!(stored_stage_keys.len(), 1);
263302
assert!(stored_stage_keys.contains(&task_keys[2]));
264303

265304
// Run the last partition of the last task.
266-
let result = do_get(1, 2, task_keys[2].clone()).await;
305+
let result = do_get(2, 2, task_keys[2].clone()).await;
267306
assert!(result.is_ok());
268307
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
269308
assert_eq!(stored_stage_keys.len(), 0);

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mod do_get;
22
mod service;
33
mod session_builder;
4-
mod trailing_flight_data_stream;
4+
pub(super) mod trailing_flight_data_stream;
55
pub(crate) use do_get::DoGet;
66

77
pub use service::ArrowFlightEndpoint;

src/flight_service/service.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming};
1717

1818
pub struct ArrowFlightEndpoint {
1919
pub(super) runtime: Arc<RuntimeEnv>,
20-
pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
20+
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
2121
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
2222
}
2323

@@ -28,7 +28,7 @@ impl ArrowFlightEndpoint {
2828
let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
2929
Ok(Self {
3030
runtime: Arc::new(RuntimeEnv::default()),
31-
task_data_entries: ttl_map,
31+
task_data_entries: Arc::new(ttl_map),
3232
session_builder: Arc::new(session_builder),
3333
})
3434
}

src/flight_service/trailing_flight_data_stream.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@ use tokio::pin;
88
/// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished.
99
/// If the closure returns a new stream, it will be appended to the original stream and consumed.
1010
#[pin_project]
11-
pub struct TrailingFlightDataStream<S, F>
11+
pub struct TrailingFlightDataStream<S, T, F>
1212
where
1313
S: Stream<Item = Result<FlightData, FlightError>> + Send,
14-
F: FnOnce() -> Result<Option<S>, FlightError>,
14+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
15+
F: FnOnce() -> Result<Option<T>, FlightError>,
1516
{
1617
#[pin]
1718
inner: S,
1819
on_complete: Option<F>,
1920
#[pin]
20-
trailing_stream: Option<S>,
21+
trailing_stream: Option<T>,
2122
}
2223

23-
impl<S, F> TrailingFlightDataStream<S, F>
24+
impl<S, T, F> TrailingFlightDataStream<S, T, F>
2425
where
2526
S: Stream<Item = Result<FlightData, FlightError>> + Send,
26-
F: FnOnce() -> Result<Option<S>, FlightError>,
27+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
28+
F: FnOnce() -> Result<Option<T>, FlightError>,
2729
{
2830
// TODO: remove
2931
#[allow(dead_code)]
@@ -36,10 +38,11 @@ where
3638
}
3739
}
3840

39-
impl<S, F> Stream for TrailingFlightDataStream<S, F>
41+
impl<S, T, F> Stream for TrailingFlightDataStream<S, T, F>
4042
where
4143
S: Stream<Item = Result<FlightData, FlightError>> + Send,
42-
F: FnOnce() -> Result<Option<S>, FlightError>,
44+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
45+
F: FnOnce() -> Result<Option<T>, FlightError>,
4346
{
4447
type Item = Result<FlightData, FlightError>;
4548

@@ -74,7 +77,7 @@ mod tests {
7477
use arrow::record_batch::RecordBatch;
7578
use arrow_flight::FlightData;
7679
use arrow_flight::decode::FlightRecordBatchStream;
77-
use arrow_flight::encode::FlightDataEncoderBuilder;
80+
use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder};
7881
use futures::stream::{self, StreamExt};
7982
use std::sync::Arc;
8083

@@ -186,7 +189,7 @@ mod tests {
186189
)))),
187190
];
188191
let inner_stream = stream::iter(data);
189-
let on_complete = || Ok(None);
192+
let on_complete = || Ok(None::<FlightDataEncoder>);
190193
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
191194
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
192195
.collect::<Vec<Result<RecordBatch, FlightError>>>()
@@ -202,8 +205,7 @@ mod tests {
202205
let name_array = StringArray::from(vec!["item1"]);
203206
let value_array = Int32Array::from(vec![1]);
204207
let inner_stream = create_flight_data_stream(name_array, value_array);
205-
206-
let on_complete = || -> Result<Option<_>, FlightError> {
208+
let on_complete = || -> Result<Option<FlightDataEncoder>, FlightError> {
207209
Err(FlightError::ExternalError(Box::new(std::io::Error::new(
208210
std::io::ErrorKind::Other,
209211
"callback error",
@@ -225,7 +227,7 @@ mod tests {
225227
StringArray::from(vec!["item1"] as Vec<&str>),
226228
Int32Array::from(vec![1] as Vec<i32>),
227229
);
228-
let on_complete = || Ok(None);
230+
let on_complete = || Ok(None::<FlightDataEncoder>);
229231
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
230232
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
231233
.collect::<Vec<Result<RecordBatch, FlightError>>>()

0 commit comments

Comments
 (0)