Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ parquet = { version = "55.2.0", optional = true }
arrow = { version = "55.2.0", optional = true }
tokio-stream = { version = "0.1.17", optional = true }
hyper-util = { version = "0.1.16", optional = true }
pin-project = "1.1.10"

[features]
integration = [
Expand Down
43 changes: 3 additions & 40 deletions src/execution_plans/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,6 @@ impl ExecutionPlan for MetricsWrapperExec {

#[cfg(test)]
mod tests {
use crate::metrics::proto::{
ElapsedCompute, EndTimestamp, MetricProto, MetricValueProto, OutputRows, StartTimestamp,
};

use super::*;
use datafusion::arrow::array::{Int32Array, StringArray};
Expand All @@ -283,6 +280,7 @@ mod tests {
use crate::DistributedExt;
use crate::DistributedPhysicalOptimizerRule;
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
use crate::test_utils::session_context::register_temp_parquet_table;
use datafusion::execution::{SessionStateBuilder, context::SessionContext};
use datafusion::physical_plan::metrics::MetricValue;
Expand Down Expand Up @@ -367,47 +365,12 @@ mod tests {
(stage_exec, ctx)
}

/// Builds a [`MetricsSetProto`] whose contents are derived from `seed`, so that different
/// seeds produce different sets.
///
/// The set holds four metrics, in this order: `OutputRows` and `ElapsedCompute` (both equal
/// to `seed`), then `StartTimestamp` and `EndTimestamp`, placed `seed` and `seed + 1` seconds
/// after a fixed base timestamp. Every metric has no labels and no partition.
fn make_distinct_metrics_set(seed: u64) -> MetricsSetProto {
    const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC
    const NANOS_PER_SECOND: i64 = 1_000_000_000;

    let offset_seconds = seed as i64;
    let values = vec![
        MetricValueProto::OutputRows(OutputRows { value: seed }),
        MetricValueProto::ElapsedCompute(ElapsedCompute { value: seed }),
        MetricValueProto::StartTimestamp(StartTimestamp {
            value: Some(TEST_TIMESTAMP + offset_seconds * NANOS_PER_SECOND),
        }),
        MetricValueProto::EndTimestamp(EndTimestamp {
            value: Some(TEST_TIMESTAMP + (offset_seconds + 1) * NANOS_PER_SECOND),
        }),
    ];

    MetricsSetProto {
        // Wrap each value in an otherwise empty MetricProto.
        metrics: values
            .into_iter()
            .map(|value| MetricProto {
                metric: Some(value),
                labels: vec![],
                partition: None,
            })
            .collect(),
    }
}

#[tokio::test]
#[ignore]
async fn test_metrics_rewriter() {
let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await;
let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec
.map(|i| make_distinct_metrics_set(i + 10))
.map(|i| make_test_metrics_set_proto_from_seed(i + 10))
.collect::<Vec<MetricsSetProto>>();

let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone());
Expand All @@ -433,7 +396,7 @@ mod tests {
#[tokio::test]
#[ignore]
async fn test_metrics_rewriter_correct_number_of_metrics() {
let test_metrics_set = make_distinct_metrics_set(10);
let test_metrics_set = make_test_metrics_set_proto_from_seed(10);
let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await;
let task_plan = executable_plan
.as_any()
Expand Down
264 changes: 264 additions & 0 deletions src/execution_plans/metrics_collecting_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
use std::pin::Pin;
use std::task::{Context, Poll};

use crate::metrics::proto::MetricsSetProto;
use crate::protobuf::StageKey;
use crate::protobuf::{AppMetadata, FlightAppMetadata};
use arrow_flight::{FlightData, error::FlightError};
use dashmap::DashMap;
use futures::stream::Stream;
use pin_project::pin_project;
use prost::Message;
use std::sync::Arc;

/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata
/// while passing through all the other FlightData unchanged.
#[pin_project]
pub struct MetricsCollectingStream<S>
where
S: Stream<Item = Result<FlightData, FlightError>> + Send,
{
#[pin]
inner: S,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should instead be:

#[pin_project]
pub struct MetricsCollectingStream<S>
where
    S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
{
    #[pin]
    inner: S,

I don't know the details about when this is needed, but I usually see wrapping streams be implemented this way in Rust.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 Will look into it.

metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
}

impl<S> MetricsCollectingStream<S>
where
    S: Stream<Item = Result<FlightData, FlightError>> + Send,
{
    /// Creates a stream that forwards the items of `stream` and records any metrics found in
    /// their `app_metadata` into `metrics_collection`.
    #[allow(dead_code)]
    pub fn new(
        stream: S,
        metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
    ) -> Self {
        let inner = stream;
        Self {
            inner,
            metrics_collection,
        }
    }

    /// Moves the metrics carried in `flight_data.app_metadata` into `metrics_collection`.
    ///
    /// A message with empty `app_metadata` is left untouched. Otherwise the bytes are decoded
    /// as a [`FlightAppMetadata`], each task's metrics are inserted under its [`StageKey`],
    /// and `app_metadata` is cleared.
    ///
    /// # Errors
    ///
    /// Returns [`FlightError::ProtocolError`] when the metadata cannot be decoded, when the
    /// decoded message has no content, or when a task has no stage key. On error
    /// `app_metadata` is not cleared, and tasks that preceded the offending one in the same
    /// message have already been inserted.
    fn extract_metrics_from_flight_data(
        metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
        flight_data: &mut FlightData,
    ) -> Result<(), FlightError> {
        // Nothing to extract from a message that carries no metadata.
        if flight_data.app_metadata.is_empty() {
            return Ok(());
        }

        let decoded = FlightAppMetadata::decode(flight_data.app_metadata.as_ref())
            .map_err(|e| {
                FlightError::ProtocolError(format!("failed to decode app_metadata: {}", e))
            })?;

        let AppMetadata::MetricsCollection(collection) = decoded.content.ok_or_else(|| {
            FlightError::ProtocolError("expected Some content in app_metadata".to_string())
        })?;

        for task in collection.tasks {
            let key = task.stage_key.ok_or_else(|| {
                FlightError::ProtocolError(
                    "expected Some StageKey in MetricsCollectingStream, got None".to_string(),
                )
            })?;
            metrics_collection.insert(key, task.metrics);
        }

        // The metrics now live in the collection; strip them from the forwarded message.
        flight_data.app_metadata.clear();
        Ok(())
    }
}

impl<S> Stream for MetricsCollectingStream<S>
where
    S: Stream<Item = Result<FlightData, FlightError>> + Send,
{
    type Item = Result<FlightData, FlightError>;

    /// Polls the wrapped stream and runs metrics extraction on every successful item.
    ///
    /// `Pending`, end-of-stream and inner errors are forwarded as they are. A successful item
    /// is yielded after extraction, or replaced by the extraction error if extraction fails.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projection = self.project();
        let (inner, collection) = (projection.inner, projection.metrics_collection);

        let Poll::Ready(next) = inner.poll_next(cx) else {
            return Poll::Pending;
        };

        let item = next.map(|result| {
            result.and_then(|mut flight_data| {
                // Extract metrics from app_metadata if present.
                Self::extract_metrics_from_flight_data(collection.clone(), &mut flight_data)?;
                Ok(flight_data)
            })
        });
        Poll::Ready(item)
    }
}

#[cfg(test)]
mod tests {
use super::*;
use crate::protobuf::{
AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
};
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
use arrow_flight::FlightData;
use futures::stream::{self, StreamExt};
use prost::{Message, bytes::Bytes};

/// Asserts that `result` is a `FlightError::ProtocolError` whose message contains
/// `expected_msg`. Panics if `result` is `Ok` or holds any other error variant.
fn assert_protocol_error(result: Result<FlightData, FlightError>, expected_msg: &str) {
    match result.unwrap_err() {
        FlightError::ProtocolError(msg) => assert!(msg.contains(expected_msg)),
        _ => panic!("expected FlightError::ProtocolError"),
    }
}

/// Builds a `FlightData` whose body is `data` and whose `app_metadata` is the encoded
/// `metadata`, or empty when `metadata` is `None`. The descriptor and header are left empty.
fn make_flight_data(data: Vec<u8>, metadata: Option<FlightAppMetadata>) -> FlightData {
    let app_metadata = metadata.map_or_else(Bytes::new, |m| Bytes::from(m.encode_to_vec()));
    FlightData {
        flight_descriptor: None,
        data_header: Bytes::new(),
        app_metadata,
        data_body: Bytes::from(data),
    }
}

// Tests that metrics are extracted from FlightData. Metrics are always stored per task, so this
// test creates some random metrics and tasks and puts them in the app_metadata field of
// FlightData messages. Then, it streams these messages through the MetricsCollectingStream
// and asserts that the metrics are collected correctly.
#[tokio::test]
async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
let stage_keys = vec![
StageKey {
query_id: "test_query".to_string(),
stage_id: 1,
task_number: 1,
},
StageKey {
query_id: "test_query".to_string(),
stage_id: 1,
task_number: 2,
},
];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you describe the purpose of this test? Do you want to test collecting metrics for one query with 2 stages, or for 2 queries that each have one stage?

If you want to test both, you may want to split them in 2 different tests and make it clear in the names of the tests or add comments to explain so. Also, do we collect metrics of different queries in one pass and need such tests?

Either way, you want to make the test valid by:

  • Have the same query_id for the test on the same query. And have different query_id if they are for 2 different queries
  • Use smaller number for child stage_id

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you describe the purpose of this test?

I added a comment.

Do you want to test collecting metrics for one query with 2 stages or you want to test collecting metrics for 2 queries each has one stage?

This metrics stream doesn't really know the semantics of how queries are executed. The only "guarantee" it offers is that it collects metrics by StageKey but it doesn't know what those are. That's why I just used random StageKeys. However, I do agree that it's confusing, so I updated the test to use same query_id and stage_id with two different task ids. This should make more sense.


let app_metadatas = stage_keys
.iter()
.map(|stage_key| FlightAppMetadata {
content: Some(AppMetadata::MetricsCollection(MetricsCollection {
tasks: vec![TaskMetrics {
stage_key: Some(stage_key.clone()),
// use the task number to seed the test metrics set for convenience
metrics: vec![make_test_metrics_set_proto_from_seed(stage_key.task_number)],
}],
})),
})
.collect::<Vec<_>>();

// Create test FlightData messages - some with metadata, some without
let flight_messages = vec![
make_flight_data(vec![1], Some(app_metadatas[0].clone())),
make_flight_data(vec![2], None),
make_flight_data(vec![3], Some(app_metadatas[1].clone())),
]
.into_iter()
.map(Ok);

// Collect all messages from the stream. All should have empty app_metadata.
let metrics_collection = Arc::new(DashMap::new());
let input_stream = stream::iter(flight_messages);
let collecting_stream =
MetricsCollectingStream::new(input_stream, metrics_collection.clone());
let collected_messages: Vec<FlightData> = collecting_stream
.map(|result| result.unwrap())
.collect()
.await;

// Assert the data is unchanged and app_metadata is cleared
assert_eq!(collected_messages.len(), 3);
assert!(
collected_messages
.iter()
.all(|msg| msg.app_metadata.is_empty())
);

// Verify the data in the messages.
assert_eq!(collected_messages[0].data_body, vec![1]);
assert_eq!(collected_messages[1].data_body, vec![2]);
assert_eq!(collected_messages[2].data_body, vec![3]);

// Verify the correct metrics were collected
assert_eq!(metrics_collection.len(), 2);
for stage_key in stage_keys {
let collected_metrics = metrics_collection.get(&stage_key).unwrap();
assert_eq!(collected_metrics.len(), 1);
assert_eq!(
collected_metrics[0],
make_test_metrics_set_proto_from_seed(stage_key.task_number)
);
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is a bit long. What do you think about making separate functions for creating metrics, stages and result verification, and then creating a macro to use them? There are many examples of that in the vanilla DF tests, and I think Gabriel has that in this repo, too. That way, you can reuse them and it is easier for us to review.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gabotechs: I am being a bit ambiguous here. It would be great if you could help point Jayant to some good examples of how to do so.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Plus I made it much more concise.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test utils for metrics creation as well :) Now, you can call make_test_metrics_set_proto_from_seed to make test data. Also all metrics protos implement Eq now so it's easy to assert metrics equality in tests


// A task entry without a stage key must surface as a protocol error from the stream.
#[tokio::test]
async fn test_metrics_collecting_stream_error_missing_stage_key() {
    let metadata_without_stage_key = FlightAppMetadata {
        content: Some(AppMetadata::MetricsCollection(MetricsCollection {
            tasks: vec![TaskMetrics {
                stage_key: None,
                metrics: vec![make_test_metrics_set_proto_from_seed(1)],
            }],
        })),
    };
    let message = make_flight_data(vec![1], Some(metadata_without_stage_key));

    let input = stream::iter(vec![Ok(message)]);
    let mut collecting = MetricsCollectingStream::new(input, Arc::new(DashMap::new()));

    let first = collecting.next().await.unwrap();
    assert_protocol_error(
        first,
        "expected Some StageKey in MetricsCollectingStream, got None",
    );
}

// app_metadata that is not a valid protobuf encoding must surface as a decode error.
#[tokio::test]
async fn test_metrics_collecting_stream_error_decoding_metadata() {
    let undecodable_metadata = Bytes::from(vec![0xFF_u8; 4]); // Invalid protobuf data
    let message = FlightData {
        flight_descriptor: None,
        data_header: Bytes::new(),
        app_metadata: undecodable_metadata,
        data_body: Bytes::from(vec![4_u8, 5, 6]),
    };

    let input = stream::iter(vec![Ok(message)]);
    let mut collecting = MetricsCollectingStream::new(input, Arc::new(DashMap::new()));

    let first = collecting.next().await.unwrap();
    assert_protocol_error(first, "failed to decode app_metadata");
}

#[tokio::test]
async fn test_metrics_collecting_stream_error_propagation() {
let metrics_collection = Arc::new(DashMap::new());

// Create a stream that emits an error - should be propagated through
let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
let error_stream = stream::iter(vec![Err(stream_error)]);
let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);

let result = collecting_stream.next().await.unwrap();
assert_protocol_error(result, "stream error from inner stream");
}
}
1 change: 1 addition & 0 deletions src/execution_plans/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod metrics;
mod metrics_collecting_stream;
mod network_coalesce;
mod network_shuffle;
mod partition_isolator;
Expand Down
Loading