Commit 3fe4b08
flight_service: emit metrics from ArrowFlightEndpoint (#160)
This change updates the ArrowFlightEndpoint to collect metrics and emit them. When the last partition in a task finishes, the ArrowFlightEndpoint collects metrics and emits them via the TrailingFlightDataStream. Previously, we determined whether a partition was finished when the request first hit the endpoint; now we do it on stream completion. This is crucial for metrics collection because we need to know that the stream is exhausted, meaning that no data is flowing in the plan and metrics are no longer being actively updated.

Since the ArrowFlightEndpoint now emits metrics and NetworkBoundary plan nodes collect them, all coordinating StageExecs now have the full collection of metrics for all tasks. This commit adds integration-style tests that assert that the coordinator receives the full set of metrics.

Finally, this change refactors and unskips some old metrics tests. They had been skipped because the plans would change; we now use test utils to count the number of nodes, etc., making these tests more resilient to changes.

Follow-up work:
- Only collect metrics if a configuration is set in the SessionContext, removing extra overhead
- Display metrics in the plan using EXPLAIN (ANALYZE)
- Consider using sqllogictest or similar to test the output

Closes: #94
Informs: #123
1 parent 0d912ba commit 3fe4b08
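To make the "emit on stream completion" idea concrete, here is a minimal sketch of the trailing-stream pattern. The `with_trailer` helper is hypothetical and its callback is simplified to be infallible; the real `TrailingFlightDataStream` lives in `src/flight_service/trailing_flight_data_stream.rs` and threads errors through.

```rust
use futures::stream::{self, Stream, StreamExt};

// Sketch only: once the inner stream is exhausted, ask a one-shot callback for
// an optional trailing stream (here, where the metrics payload would go).
fn with_trailer<S, F, T>(inner: S, on_complete: F) -> impl Stream<Item = S::Item>
where
    S: Stream,
    F: FnOnce() -> Option<T>,
    T: Stream<Item = S::Item>,
{
    inner.chain(
        // `stream::once` polls this future only after `inner` finishes, so the
        // callback observes a fully drained stream before emitting trailers.
        stream::once(async move { on_complete() })
            .filter_map(|trailer| async move { trailer })
            .flatten(),
    )
}
```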

File tree

15 files changed: +882 −485 lines

src/execution_plans/metrics.rs

Lines changed: 6 additions & 405 deletions
Large diffs are not rendered by default.
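The collapsed diff here moves most of the old metrics code out and, judging by the call sites in do_get.rs below, introduces a task-level collector. A rough, inferred shape follows; the field and type names are taken from usage below, not from the collapsed diff itself, so treat this as an approximation.

```rust
use std::collections::HashMap;

use datafusion::physical_plan::metrics::MetricsSet;

use crate::metrics::proto::MetricsSetProto;
use crate::protobuf::StageKey;

// Inferred from `TaskMetricsCollector::new().collect(plan)` usage in do_get.rs:
// `task_metrics` holds this task's own metrics (one MetricsSet per plan node),
// while `input_task_metrics` holds already-encoded metrics forwarded from
// child tasks, keyed by their StageKey.
struct TaskMetricsCollectorResult {
    task_metrics: Vec<MetricsSet>,
    input_task_metrics: HashMap<StageKey, Vec<MetricsSetProto>>,
}
```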

src/execution_plans/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -1,10 +1,10 @@
 mod metrics;
-mod metrics_collecting_stream;
 mod network_coalesce;
 mod network_shuffle;
 mod partition_isolator;
 mod stage;
 
+pub use metrics::MetricsWrapperExec;
 pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady};
 pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec};
 pub use partition_isolator::PartitionIsolatorExec;

src/execution_plans/network_coalesce.rs

Lines changed: 9 additions & 2 deletions
@@ -5,6 +5,7 @@ use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err};
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::DoGet;
+use crate::metrics::MetricsCollectingStream;
 use crate::metrics::proto::MetricsSetProto;
 use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage};
 use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
@@ -282,6 +283,7 @@ impl ExecutionPlan for NetworkCoalesceExec {
             return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed");
         };
 
+        let metrics_collection_capture = self_ready.metrics_collection.clone();
         let stream = async move {
             let channel = channel_resolver.get_channel_for_url(&url).await?;
             let stream = FlightServiceClient::new(channel)
@@ -291,8 +293,13 @@ impl ExecutionPlan for NetworkCoalesceExec {
                 .into_inner()
                 .map_err(|err| FlightError::Tonic(Box::new(err)));
 
-            Ok(FlightRecordBatchStream::new_from_flight_data(stream)
-                .map_err(map_flight_to_datafusion_error))
+            let metrics_collecting_stream =
+                MetricsCollectingStream::new(stream, metrics_collection_capture);
+
+            Ok(
+                FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
+                    .map_err(map_flight_to_datafusion_error),
+            )
         }
         .try_flatten_stream();

src/execution_plans/network_shuffle.rs

Lines changed: 9 additions & 2 deletions
@@ -5,6 +5,7 @@ use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::distributed_physical_optimizer_rule::NetworkBoundary;
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::DoGet;
+use crate::metrics::MetricsCollectingStream;
 use crate::metrics::proto::MetricsSetProto;
 use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage};
 use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
@@ -330,6 +331,7 @@ impl ExecutionPlan for NetworkShuffleExec {
             },
         );
 
+        let metrics_collection_capture = self_ready.metrics_collection.clone();
         async move {
             let url = task.url.ok_or(internal_datafusion_err!(
                 "NetworkShuffleExec: task is unassigned, cannot proceed"
@@ -343,8 +345,13 @@ impl ExecutionPlan for NetworkShuffleExec {
                 .into_inner()
                 .map_err(|err| FlightError::Tonic(Box::new(err)));
 
-            Ok(FlightRecordBatchStream::new_from_flight_data(stream)
-                .map_err(map_flight_to_datafusion_error))
+            let metrics_collecting_stream =
+                MetricsCollectingStream::new(stream, metrics_collection_capture);
+
+            Ok(
+                FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
+                    .map_err(map_flight_to_datafusion_error),
+            )
         }
         .try_flatten_stream()
        .boxed()
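Both network nodes now thread the raw Flight stream through `MetricsCollectingStream` before decoding record batches. Below is a hedged sketch of what such a wrapper can do with `futures::StreamExt::inspect`; the sink type is a stand-in, and the real stream (defined in the new `src/metrics` module) stores decoded metrics into the node's `metrics_collection` keyed by `StageKey` and likely strips the metadata before forwarding.

```rust
use std::sync::{Arc, Mutex};

use arrow_flight::FlightData;
use futures::stream::{Stream, StreamExt};
use prost::Message;

use crate::protobuf::FlightAppMetadata;

// Sketch: peek at each FlightData's app_metadata, stash anything that decodes
// as metrics, and forward the item unchanged.
fn skim_metrics<S, E>(
    inner: S,
    sink: Arc<Mutex<Vec<FlightAppMetadata>>>, // stand-in for metrics_collection
) -> impl Stream<Item = Result<FlightData, E>>
where
    S: Stream<Item = Result<FlightData, E>>,
{
    inner.inspect(move |item| {
        if let Ok(flight_data) = item {
            if !flight_data.app_metadata.is_empty() {
                if let Ok(meta) = FlightAppMetadata::decode(&flight_data.app_metadata[..]) {
                    sink.lock().unwrap().push(meta);
                }
            }
        }
    })
}
```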

src/flight_service/do_get.rs

Lines changed: 141 additions & 20 deletions
@@ -3,9 +3,17 @@ use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
+use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
+use crate::metrics::TaskMetricsCollector;
+use crate::metrics::proto::df_metrics_set_to_proto;
 use crate::protobuf::{
-    DistributedCodec, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
+    AppMetadata, DistributedCodec, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
+    datafusion_error_to_tonic_status, stage_from_proto,
 };
+use arrow::array::RecordBatch;
+use arrow::datatypes::SchemaRef;
+use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
+use arrow_flight::FlightData;
 use arrow_flight::Ticket;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
@@ -15,6 +23,7 @@ use datafusion::common::exec_datafusion_err;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use futures::TryStreamExt;
+use futures::{Stream, stream};
 use prost::Message;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -97,12 +106,6 @@ impl ArrowFlightEndpoint {
             })
             .await?;
         let stage = Arc::clone(&stage_data.stage);
-        let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
-
-        // If all the partitions are done, remove the stage from the cache.
-        if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
-            self.task_data_entries.remove(key);
-        }
 
         // Find out which partition group we are executing
         let cfg = session_state.config_mut();
@@ -130,24 +133,42 @@ impl ArrowFlightEndpoint {
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let schema = stream.schema();
+
+        // TODO: We don't need to do this since the stage / plan is captured again by the
+        // TrailingFlightDataStream. However, we will eventuall only use the TrailingFlightDataStream
+        // if we are running an `explain (analyze)` command. We should update this section
+        // to only use one or the other - not both.
+        let plan_capture = stage.plan.clone();
         let stream = with_callback(stream, move |_| {
             // We need to hold a reference to the plan for at least as long as the stream is
             // execution. Some plans might store state necessary for the stream to work, and
             // dropping the plan early could drop this state too soon.
-            let _ = stage.plan;
+            let _ = plan_capture;
         });
 
-        Ok(record_batch_stream_to_response(Box::pin(
-            RecordBatchStreamAdapter::new(schema, stream),
-        )))
+        let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+        let task_data_capture = self.task_data_entries.clone();
+        Ok(flight_stream_from_record_batch_stream(
+            key.clone(),
+            stage_data.clone(),
+            move || {
+                task_data_capture.remove(key.clone());
+            },
+            record_batch_stream,
+        ))
     }
 }
 
 fn missing(field: &'static str) -> impl FnOnce() -> Status {
     move || Status::invalid_argument(format!("Missing field '{field}'"))
 }
 
-fn record_batch_stream_to_response(
+/// Creates a tonic response from a stream of record batches. Handles
+/// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics.
+fn flight_stream_from_record_batch_stream(
+    stage_key: StageKey,
+    stage_data: TaskData,
+    evict_stage: impl FnOnce() + Send + 'static,
     stream: SendableRecordBatchStream,
 ) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
     let flight_data_stream =
@@ -157,12 +178,109 @@ fn record_batch_stream_to_response(
             FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
         }));
 
-    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+    let trailing_metrics_stream = TrailingFlightDataStream::new(
+        move || {
+            if stage_data
+                .num_partitions_remaining
+                .fetch_sub(1, Ordering::SeqCst)
+                == 1
+            {
+                evict_stage();
+
+                let metrics_stream =
+                    collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
+                        |err| {
+                            Status::internal(format!(
+                                "error collecting metrics in arrow flight endpoint: {err}"
+                            ))
+                        },
+                    )?;
+
+                return Ok(Some(metrics_stream));
+            }
+
+            Ok(None)
+        },
+        flight_data_stream,
+    );
+
+    Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
         FlightError::Tonic(status) => *status,
         _ => Status::internal(format!("Error during flight stream: {err}")),
     })))
 }
 
+// Collects metrics from the provided stage and encodes it into a stream of flight data using
+// the schema of the stage.
+fn collect_and_create_metrics_flight_data(
+    stage_key: StageKey,
+    stage: Arc<StageExec>,
+) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> {
+    // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
+    let mut result = TaskMetricsCollector::new()
+        .collect(stage.plan.clone())
+        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
+
+    // Add the metrics for this task into the collection of task metrics.
+    // Skip any metrics that can't be converted to proto (unsupported types)
+    let proto_task_metrics = result
+        .task_metrics
+        .iter()
+        .map(|metrics| {
+            df_metrics_set_to_proto(metrics)
+                .map_err(|err| FlightError::ProtocolError(err.to_string()))
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+    result
+        .input_task_metrics
+        .insert(stage_key, proto_task_metrics);
+
+    // Serialize the metrics for all tasks.
+    let mut task_metrics_set = vec![];
+    for (stage_key, metrics) in result.input_task_metrics.into_iter() {
+        task_metrics_set.push(TaskMetrics {
+            stage_key: Some(stage_key),
+            metrics,
+        });
+    }
+
+    let flight_app_metadata = FlightAppMetadata {
+        content: Some(AppMetadata::MetricsCollection(MetricsCollection {
+            tasks: task_metrics_set,
+        })),
+    };
+
+    let metrics_flight_data =
+        empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
+    Ok(Box::pin(stream::once(
+        async move { Ok(metrics_flight_data) },
+    )))
+}
+
+/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
+/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
+/// since they skip messages with empty RecordBatch data.
+pub fn empty_flight_data_with_app_metadata(
+    metadata: FlightAppMetadata,
+    schema: SchemaRef,
+) -> Result<FlightData, FlightError> {
+    let mut buf = vec![];
+    metadata
+        .encode(&mut buf)
+        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
+
+    let empty_batch = RecordBatch::new_empty(schema);
+    let options = IpcWriteOptions::default();
+    let data_gen = IpcDataGenerator::default();
+    let mut dictionary_tracker = DictionaryTracker::new(true);
+    let (_, encoded_data) = data_gen
+        .encoded_batch(&empty_batch, &mut dictionary_tracker, &options)
+        .map_err(|e| {
+            FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}"))
+        })?;
+    Ok(FlightData::from(encoded_data).with_app_metadata(buf))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -228,24 +346,27 @@ mod tests {
         let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap();
         let stage_proto_for_closure = stage_proto.clone();
         let endpoint_ref = &endpoint;
+
         let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
             let stage_proto = stage_proto_for_closure.clone();
-            // Create DoGet message
             let doget = DoGet {
                 stage_proto: stage_proto.encode_to_vec().into(),
                 target_task_index: task_number,
                 target_partition: partition,
                 stage_key: Some(stage_key),
             };
 
-            // Create Flight ticket
             let ticket = Ticket {
                 ticket: Bytes::from(doget.encode_to_vec()),
             };
 
-            // Call the actual get() method
             let request = Request::new(ticket);
-            endpoint_ref.get(request).await
+            let response = endpoint_ref.get(request).await?;
+            let mut stream = response.into_inner();
+
+            // Consume the stream.
+            while let Some(_flight_data) = stream.try_next().await? {}
+            Ok::<(), Status>(())
         };
 
         // For each task, call do_get() for each partition except the last.
@@ -261,22 +382,22 @@ mod tests {
 
         // Run the last partition of task 0. Any partition number works. Verify that the task state
        // is evicted because all partitions have been processed.
-        let result = do_get(1, 0, task_keys[0].clone()).await;
+        let result = do_get(2, 0, task_keys[0].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 2);
         assert!(stored_stage_keys.contains(&task_keys[1]));
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of task 1.
-        let result = do_get(1, 1, task_keys[1].clone()).await;
+        let result = do_get(2, 1, task_keys[1].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 1);
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of the last task.
-        let result = do_get(1, 2, task_keys[2].clone()).await;
+        let result = do_get(2, 2, task_keys[2].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 0);
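One subtlety in flight_stream_from_record_batch_stream above: the partition countdown moved from request time (`<= 1`) to stream completion (`== 1`). A standalone, illustrative demonstration of why `fetch_sub(..) == 1` singles out exactly one caller, the last one to finish:

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

// fetch_sub returns the value *before* subtracting, so the caller that
// observes 1 is the one that drove the counter to 0 - i.e. the last partition
// to finish. Earlier callers observe 3, 2, ...
fn main() {
    let remaining = Arc::new(AtomicUsize::new(3)); // e.g. three partitions
    let handles: Vec<_> = (0..3)
        .map(|_| {
            let remaining = Arc::clone(&remaining);
            std::thread::spawn(move || remaining.fetch_sub(1, Ordering::SeqCst) == 1)
        })
        .collect();
    let last_finishers = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .filter(|&was_last| was_last)
        .count();
    assert_eq!(last_finishers, 1); // exactly one stream completion wins
    assert_eq!(remaining.load(Ordering::SeqCst), 0);
}
```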

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 mod do_get;
 mod service;
 mod session_builder;
-mod trailing_flight_data_stream;
+pub(super) mod trailing_flight_data_stream;
 pub(crate) use do_get::DoGet;
 
 pub use service::ArrowFlightEndpoint;

src/flight_service/service.rs

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming};
 
 pub struct ArrowFlightEndpoint {
     pub(super) runtime: Arc<RuntimeEnv>,
-    pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
+    pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }
 
@@ -28,7 +28,7 @@ impl ArrowFlightEndpoint {
         let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
         Ok(Self {
             runtime: Arc::new(RuntimeEnv::default()),
-            task_data_entries: ttl_map,
+            task_data_entries: Arc::new(ttl_map),
             session_builder: Arc::new(session_builder),
         })
     }
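The `Arc` around `TTLMap` supports the change in do_get.rs: eviction now happens inside a `'static` closure handed to the trailing stream, which cannot borrow from `&self`. A minimal sketch of the pattern with stand-in types (a plain `Mutex<HashMap>` instead of the real `TTLMap`):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Stand-in types: the point is that the closure owns an Arc clone of the map,
// so it can outlive the request handler that created it.
fn make_evictor(
    entries: Arc<Mutex<HashMap<String, u64>>>,
    key: String,
) -> impl FnOnce() + Send + 'static {
    move || {
        entries.lock().unwrap().remove(&key);
    }
}
```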
