Skip to content

Commit 456cafd

Browse files
flight_service: emit metrics from ArrowFlightEndpoint
This change updates the ArrowFlightEndpoint to collect metrics and emit them. When the last partition in a task is finished, the ArrowFlightEndpoint collects metrics and emits them via the TrailingFlightDataStream. Previously, we would determine if a partition is finished when the request first hit the endpoint. Now, we do it on stream completion. This is crucial for metrics collection because we need to know that the stream is exhausted, meaning that there's no data flowing in the plan and metrics are not actively being updated. Since the ArrowFlightEndpoint now emits metrics and NetworkBoundary plan nodes collect metrics, all coordinating StageExecs will now have the full collection of metrics for all tasks. This commit adds integration-style tests that assert that the coordinator is receiving the full set of metrics. Finally, this change refactors and unskips some old metrics tests. These tests were skipped because the plans would change. Now, we use test utils to count the number of nodes, etc., to make these tests more resilient to changes. Follow-up work: - Only collect metrics if a configuration is set in the SessionContext, removing extra overhead - Display metrics in the plan using EXPLAIN (ANALYZE) - Consider using sqllogictest or similar to test the output Closes: #94 Informs: #123
1 parent 0d912ba commit 456cafd

File tree

12 files changed

+524
-221
lines changed

12 files changed

+524
-221
lines changed

src/execution_plans/metrics.rs

Lines changed: 321 additions & 149 deletions
Large diffs are not rendered by default.

src/execution_plans/metrics_collecting_stream.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ where
6767
};
6868
metrics_collection.insert(stage_key, task_metrics.metrics);
6969
}
70-
7170
flight_data.app_metadata.clear();
71+
7272
Ok(())
7373
}
7474
}
@@ -256,7 +256,8 @@ mod tests {
256256
// Create a stream that emits an error - should be propagated through
257257
let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
258258
let error_stream = stream::iter(vec![Err(stream_error)]);
259-
let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);
259+
let mut collecting_stream =
260+
MetricsCollectingStream::new(error_stream, metrics_collection.clone());
260261

261262
let result = collecting_stream.next().await.unwrap();
262263
assert_protocol_error(result, "stream error from inner stream");

src/execution_plans/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod network_shuffle;
55
mod partition_isolator;
66
mod stage;
77

8+
pub use metrics::collect_and_create_metrics_flight_data;
89
pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady};
910
pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec};
1011
pub use partition_isolator::PartitionIsolatorExec;

src/execution_plans/network_coalesce.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver;
33
use crate::common::scale_partitioning_props;
44
use crate::config_extension_ext::ContextGrpcMetadata;
55
use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err};
6+
use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream;
67
use crate::execution_plans::{DistributedTaskContext, StageExec};
78
use crate::flight_service::DoGet;
89
use crate::metrics::proto::MetricsSetProto;
@@ -282,6 +283,7 @@ impl ExecutionPlan for NetworkCoalesceExec {
282283
return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed");
283284
};
284285

286+
let metrics_collection_capture = self_ready.metrics_collection.clone();
285287
let stream = async move {
286288
let channel = channel_resolver.get_channel_for_url(&url).await?;
287289
let stream = FlightServiceClient::new(channel)
@@ -291,8 +293,13 @@ impl ExecutionPlan for NetworkCoalesceExec {
291293
.into_inner()
292294
.map_err(|err| FlightError::Tonic(Box::new(err)));
293295

294-
Ok(FlightRecordBatchStream::new_from_flight_data(stream)
295-
.map_err(map_flight_to_datafusion_error))
296+
let metrics_collecting_stream =
297+
MetricsCollectingStream::new(stream, metrics_collection_capture);
298+
299+
Ok(
300+
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
301+
.map_err(map_flight_to_datafusion_error),
302+
)
296303
}
297304
.try_flatten_stream();
298305

src/execution_plans/network_shuffle.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver;
33
use crate::common::scale_partitioning;
44
use crate::config_extension_ext::ContextGrpcMetadata;
55
use crate::distributed_physical_optimizer_rule::NetworkBoundary;
6+
use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream;
67
use crate::execution_plans::{DistributedTaskContext, StageExec};
78
use crate::flight_service::DoGet;
89
use crate::metrics::proto::MetricsSetProto;
@@ -330,6 +331,7 @@ impl ExecutionPlan for NetworkShuffleExec {
330331
},
331332
);
332333

334+
let metrics_collection_capture = self_ready.metrics_collection.clone();
333335
async move {
334336
let url = task.url.ok_or(internal_datafusion_err!(
335337
"NetworkShuffleExec: task is unassigned, cannot proceed"
@@ -343,8 +345,13 @@ impl ExecutionPlan for NetworkShuffleExec {
343345
.into_inner()
344346
.map_err(|err| FlightError::Tonic(Box::new(err)));
345347

346-
Ok(FlightRecordBatchStream::new_from_flight_data(stream)
347-
.map_err(map_flight_to_datafusion_error))
348+
let metrics_collecting_stream =
349+
MetricsCollectingStream::new(stream, metrics_collection_capture);
350+
351+
Ok(
352+
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
353+
.map_err(map_flight_to_datafusion_error),
354+
)
348355
}
349356
.try_flatten_stream()
350357
.boxed()

src/flight_service/do_get.rs

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
use crate::common::with_callback;
22
use crate::config_extension_ext::ContextGrpcMetadata;
3-
use crate::execution_plans::{DistributedTaskContext, StageExec};
3+
use crate::execution_plans::{
4+
DistributedTaskContext, StageExec, collect_and_create_metrics_flight_data,
5+
};
46
use crate::flight_service::service::ArrowFlightEndpoint;
57
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
8+
use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
69
use crate::protobuf::{
710
DistributedCodec, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
811
};
@@ -97,12 +100,6 @@ impl ArrowFlightEndpoint {
97100
})
98101
.await?;
99102
let stage = Arc::clone(&stage_data.stage);
100-
let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
101-
102-
// If all the partitions are done, remove the stage from the cache.
103-
if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
104-
self.task_data_entries.remove(key);
105-
}
106103

107104
// Find out which partition group we are executing
108105
let cfg = session_state.config_mut();
@@ -130,24 +127,44 @@ impl ArrowFlightEndpoint {
130127
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
131128

132129
let schema = stream.schema();
130+
131+
// TODO: We don't need to do this since the stage / plan is captured again by the
132+
// TrailingFlightDataStream. However, we will eventually only use the TrailingFlightDataStream
133+
// if we are running an `explain (analyze)` command. We should update this section
134+
// to only use one or the other - not both.
135+
let plan_capture = stage.plan.clone();
133136
let stream = with_callback(stream, move |_| {
134137
// We need to hold a reference to the plan for at least as long as the stream is
135138
// execution. Some plans might store state necessary for the stream to work, and
136139
// dropping the plan early could drop this state too soon.
137-
let _ = stage.plan;
140+
let _ = plan_capture;
138141
});
139142

140-
Ok(record_batch_stream_to_response(Box::pin(
141-
RecordBatchStreamAdapter::new(schema, stream),
142-
)))
143+
let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
144+
let task_data_capture = self.task_data_entries.clone();
145+
Ok(flight_stream_from_record_batch_stream(
146+
key.clone(),
147+
stage,
148+
stage_data.clone(),
149+
move || {
150+
task_data_capture.remove(key.clone());
151+
},
152+
record_batch_stream,
153+
))
143154
}
144155
}
145156

146157
fn missing(field: &'static str) -> impl FnOnce() -> Status {
147158
move || Status::invalid_argument(format!("Missing field '{field}'"))
148159
}
149160

150-
fn record_batch_stream_to_response(
161+
// Creates a tonic response from a stream of record batches. Handles
162+
// - RecordBatch to flight conversion, partition tracking, stage eviction, and trailing metrics.
163+
fn flight_stream_from_record_batch_stream(
164+
stage_key: StageKey,
165+
stage: Arc<StageExec>,
166+
stage_data: TaskData,
167+
evict_stage: impl FnOnce() + Send + 'static,
151168
stream: SendableRecordBatchStream,
152169
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
153170
let flight_data_stream =
@@ -157,7 +174,31 @@ fn record_batch_stream_to_response(
157174
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
158175
}));
159176

160-
Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
177+
let trailing_metrics_stream = TrailingFlightDataStream::new(
178+
move || {
179+
if stage_data
180+
.num_partitions_remaining
181+
.fetch_sub(1, Ordering::SeqCst)
182+
== 1
183+
{
184+
evict_stage();
185+
186+
let metrics_stream = collect_and_create_metrics_flight_data(stage_key, stage)
187+
.map_err(|err| {
188+
Status::internal(format!(
189+
"error collecting metrics in arrow flight endpoint: {err}"
190+
))
191+
})?;
192+
193+
return Ok(Some(metrics_stream));
194+
}
195+
196+
Ok(None)
197+
},
198+
flight_data_stream,
199+
);
200+
201+
Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
161202
FlightError::Tonic(status) => *status,
162203
_ => Status::internal(format!("Error during flight stream: {err}")),
163204
})))
@@ -228,24 +269,27 @@ mod tests {
228269
let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap();
229270
let stage_proto_for_closure = stage_proto.clone();
230271
let endpoint_ref = &endpoint;
272+
231273
let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
232274
let stage_proto = stage_proto_for_closure.clone();
233-
// Create DoGet message
234275
let doget = DoGet {
235276
stage_proto: stage_proto.encode_to_vec().into(),
236277
target_task_index: task_number,
237278
target_partition: partition,
238279
stage_key: Some(stage_key),
239280
};
240281

241-
// Create Flight ticket
242282
let ticket = Ticket {
243283
ticket: Bytes::from(doget.encode_to_vec()),
244284
};
245285

246-
// Call the actual get() method
247286
let request = Request::new(ticket);
248-
endpoint_ref.get(request).await
287+
let response = endpoint_ref.get(request).await?;
288+
let mut stream = response.into_inner();
289+
290+
// Consume the stream.
291+
while let Some(_flight_data) = stream.try_next().await? {}
292+
Ok::<(), Status>(())
249293
};
250294

251295
// For each task, call do_get() for each partition except the last.
@@ -261,22 +305,22 @@ mod tests {
261305

262306
// Run the last partition of task 0. Any partition number works. Verify that the task state
263307
// is evicted because all partitions have been processed.
264-
let result = do_get(1, 0, task_keys[0].clone()).await;
308+
let result = do_get(2, 0, task_keys[0].clone()).await;
265309
assert!(result.is_ok());
266310
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
267311
assert_eq!(stored_stage_keys.len(), 2);
268312
assert!(stored_stage_keys.contains(&task_keys[1]));
269313
assert!(stored_stage_keys.contains(&task_keys[2]));
270314

271315
// Run the last partition of task 1.
272-
let result = do_get(1, 1, task_keys[1].clone()).await;
316+
let result = do_get(2, 1, task_keys[1].clone()).await;
273317
assert!(result.is_ok());
274318
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
275319
assert_eq!(stored_stage_keys.len(), 1);
276320
assert!(stored_stage_keys.contains(&task_keys[2]));
277321

278322
// Run the last partition of the last task.
279-
let result = do_get(1, 2, task_keys[2].clone()).await;
323+
let result = do_get(2, 2, task_keys[2].clone()).await;
280324
assert!(result.is_ok());
281325
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
282326
assert_eq!(stored_stage_keys.len(), 0);

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mod do_get;
22
mod service;
33
mod session_builder;
4-
mod trailing_flight_data_stream;
4+
pub(super) mod trailing_flight_data_stream;
55
pub(crate) use do_get::DoGet;
66

77
pub use service::ArrowFlightEndpoint;

src/flight_service/service.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming};
1717

1818
pub struct ArrowFlightEndpoint {
1919
pub(super) runtime: Arc<RuntimeEnv>,
20-
pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
20+
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
2121
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
2222
}
2323

@@ -28,7 +28,7 @@ impl ArrowFlightEndpoint {
2828
let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
2929
Ok(Self {
3030
runtime: Arc::new(RuntimeEnv::default()),
31-
task_data_entries: ttl_map,
31+
task_data_entries: Arc::new(ttl_map),
3232
session_builder: Arc::new(session_builder),
3333
})
3434
}

src/flight_service/trailing_flight_data_stream.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@ use tokio::pin;
88
/// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished.
99
/// If the closure returns a new stream, it will be appended to the original stream and consumed.
1010
#[pin_project]
11-
pub struct TrailingFlightDataStream<S, F>
11+
pub struct TrailingFlightDataStream<S, T, F>
1212
where
1313
S: Stream<Item = Result<FlightData, FlightError>> + Send,
14-
F: FnOnce() -> Result<Option<S>, FlightError>,
14+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
15+
F: FnOnce() -> Result<Option<T>, FlightError>,
1516
{
1617
#[pin]
1718
inner: S,
1819
on_complete: Option<F>,
1920
#[pin]
20-
trailing_stream: Option<S>,
21+
trailing_stream: Option<T>,
2122
}
2223

23-
impl<S, F> TrailingFlightDataStream<S, F>
24+
impl<S, T, F> TrailingFlightDataStream<S, T, F>
2425
where
2526
S: Stream<Item = Result<FlightData, FlightError>> + Send,
26-
F: FnOnce() -> Result<Option<S>, FlightError>,
27+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
28+
F: FnOnce() -> Result<Option<T>, FlightError>,
2729
{
2830
// TODO: remove
2931
#[allow(dead_code)]
@@ -36,10 +38,11 @@ where
3638
}
3739
}
3840

39-
impl<S, F> Stream for TrailingFlightDataStream<S, F>
41+
impl<S, T, F> Stream for TrailingFlightDataStream<S, T, F>
4042
where
4143
S: Stream<Item = Result<FlightData, FlightError>> + Send,
42-
F: FnOnce() -> Result<Option<S>, FlightError>,
44+
T: Stream<Item = Result<FlightData, FlightError>> + Send,
45+
F: FnOnce() -> Result<Option<T>, FlightError>,
4346
{
4447
type Item = Result<FlightData, FlightError>;
4548

@@ -74,7 +77,7 @@ mod tests {
7477
use arrow::record_batch::RecordBatch;
7578
use arrow_flight::FlightData;
7679
use arrow_flight::decode::FlightRecordBatchStream;
77-
use arrow_flight::encode::FlightDataEncoderBuilder;
80+
use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder};
7881
use futures::stream::{self, StreamExt};
7982
use std::sync::Arc;
8083

@@ -186,7 +189,7 @@ mod tests {
186189
)))),
187190
];
188191
let inner_stream = stream::iter(data);
189-
let on_complete = || Ok(None);
192+
let on_complete = || Ok(None::<FlightDataEncoder>);
190193
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
191194
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
192195
.collect::<Vec<Result<RecordBatch, FlightError>>>()
@@ -202,8 +205,7 @@ mod tests {
202205
let name_array = StringArray::from(vec!["item1"]);
203206
let value_array = Int32Array::from(vec![1]);
204207
let inner_stream = create_flight_data_stream(name_array, value_array);
205-
206-
let on_complete = || -> Result<Option<_>, FlightError> {
208+
let on_complete = || -> Result<Option<FlightDataEncoder>, FlightError> {
207209
Err(FlightError::ExternalError(Box::new(std::io::Error::new(
208210
std::io::ErrorKind::Other,
209211
"callback error",
@@ -225,7 +227,7 @@ mod tests {
225227
StringArray::from(vec!["item1"] as Vec<&str>),
226228
Int32Array::from(vec![1] as Vec<i32>),
227229
);
228-
let on_complete = || Ok(None);
230+
let on_complete = || Ok(None::<FlightDataEncoder>);
229231
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
230232
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
231233
.collect::<Vec<Result<RecordBatch, FlightError>>>()

0 commit comments

Comments
 (0)