Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@ After generating the data with the command above:

```shell
cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1
```
```

In order to validate the correctness of the results against single node execution, add
`--validate`

```shell
cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 --validate
```
82 changes: 65 additions & 17 deletions benchmarks/src/tpch/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,12 @@ pub struct RunOpt {
#[structopt(short = "t", long = "sorted")]
sorted: bool,

/// Mark the first column of each table as sorted in ascending order.
/// The tables should have been created with the `--sort` option for this to have any effect.
/// The maximum number of partitions per task.
#[structopt(long = "ppt")]
partitions_per_task: Option<usize>,

#[structopt(long = "validate")]
validate: bool,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about moving forward with @jayshrivastava's changes in #83 for validating TPCH correctness instead of this? it might be slightly better to ensure validation there because:

  • It would be nice to touch this code as little as possible, as this is pretty much vendored code from upstream DataFusion, and if we decide to move this project there or upstream DataFusion decides to make their benchmarks crate public, it would be difficult to port because of conflicts
  • We want to ensure TPCH correctness in the CI, so it might be more suitable to do it as a mandatory test suite using Cargo test tools rather than an optional step during the benchmarks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep i thought the same thing and i moved it out of the benchmarks and aligned with @jayshrivastava 's PR

}

#[async_trait]
Expand All @@ -115,8 +117,23 @@ impl SessionBuilder for RunOpt {
let mut config = self
.common
.config()?
.with_collect_statistics(!self.disable_statistics);
.with_collect_statistics(!self.disable_statistics)
.with_target_partitions(self.partitions());

// FIXME: these three options are critical for the correct function of the library
// but we are not enforcing that the user sets them. They are here at the moment
// but we should figure out a way to do this better.
config
.options_mut()
.optimizer
.hash_join_single_partition_threshold = 0;
config
.options_mut()
.optimizer
.hash_join_single_partition_threshold_rows = 0;

config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join;
// end critical options section
let rt_builder = self.common.runtime_env_builder()?;

let mut rule = DistributedPhysicalOptimizerRule::new();
Expand All @@ -140,7 +157,7 @@ impl SessionBuilder for RunOpt {

impl RunOpt {
pub async fn run(self) -> Result<()> {
let (ctx, _guard) = start_localhost_context([50051], self.clone()).await;
let (ctx, _guard) = start_localhost_context([40051], self.clone()).await;
println!("Running benchmarks with the following options: {self:?}");
let query_range = match self.query {
Some(query_id) => query_id..=query_id,
Expand Down Expand Up @@ -180,23 +197,22 @@ impl RunOpt {

let sql = &get_query_sql(query_id)?;

let single_node_ctx = SessionContext::new();
self.register_tables(&single_node_ctx).await?;

for i in 0..self.iterations() {
let start = Instant::now();
let mut result = vec![];

// query 15 is special, with 3 statements. the second statement is the one from which we
// want to capture the results
let mut result = vec![];
if query_id == 15 {
for (n, query) in sql.iter().enumerate() {
if n == 1 {
result = self.execute_query(ctx, query).await?;
} else {
self.execute_query(ctx, query).await?;
}
}
} else {
for query in sql {
let result_stmt = if query_id == 15 { 1 } else { sql.len() - 1 };

for (i, query) in sql.iter().enumerate() {
if i == result_stmt {
result = self.execute_query(ctx, query).await?;
} else {
self.execute_query(ctx, query).await?;
}
}

Expand All @@ -208,11 +224,43 @@ impl RunOpt {
println!(
"Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
);
query_results.push(QueryResult { elapsed, row_count });

let valid = if self.validate {
let mut single_node_result = vec![];
for (i, query) in sql.iter().enumerate() {
if i == result_stmt {
single_node_result = self.execute_query(&single_node_ctx, query).await?;
} else {
self.execute_query(&single_node_ctx, query).await?;
}
}

let res = pretty::pretty_format_batches(&result)?.to_string();
let single_node_res =
pretty::pretty_format_batches(&single_node_result)?.to_string();
res == single_node_res
} else {
false
};

query_results.push(QueryResult {
valid,
elapsed,
row_count,
});
}

let avg = millis.iter().sum::<f64>() / millis.len() as f64;
println!("Query {query_id} avg time: {avg:.2} ms");
let valid_str = if self.validate {
if query_results[query_results.len() - 1].valid {
"valid"
} else {
"invalid"
}
} else {
""
};
println!("Query {query_id} avg time: {avg:.2} ms {valid_str}");

// Print memory stats using mimalloc (only when compiled with --features mimalloc_extended)
print_memory_stats();
Expand Down
1 change: 1 addition & 0 deletions benchmarks/src/util/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub struct BenchQuery {
pub struct QueryResult {
pub elapsed: Duration,
pub row_count: usize,
pub valid: bool,
}
/// collects benchmark run data and then serializes it at the end
pub struct BenchmarkRun {
Expand Down
20 changes: 0 additions & 20 deletions context.rs

This file was deleted.

164 changes: 108 additions & 56 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::plan::DistributedCodec;
use crate::stage::{stage_from_proto, ExecutionStageProto};
use crate::plan::{DistributedCodec, PartitionGroup};
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
use crate::user_provided_codec::get_user_codec;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::optimizer::OptimizerConfig;
use datafusion::prelude::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tonic::{Request, Response, Status};

use super::service::StageKey;

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DoGet {
/// The ExecutionStage that we are going to execute
#[prost(message, optional, tag = "1")]
pub stage_proto: Option<ExecutionStageProto>,
/// the partition of the stage to execute
/// The index to the task within the stage that we want to execute
#[prost(uint64, tag = "2")]
pub task_number: u64,
/// the partition number we want to execute
#[prost(uint64, tag = "3")]
pub partition: u64,
/// The stage key that identifies the stage. This is useful to keep
/// outside of the stage proto as it is used to store the stage
/// and we may not need to deserialize the entire stage proto
/// if we already have stored it
#[prost(message, optional, tag = "4")]
pub stage_key: Option<StageKey>,
}

impl ArrowFlightEndpoint {
Expand All @@ -36,59 +48,35 @@ impl ArrowFlightEndpoint {
Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
})?;

let stage_msg = doget
.stage_proto
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

let state_builder = SessionStateBuilder::new()
.with_runtime_env(Arc::clone(&self.runtime))
.with_default_features();
let state_builder = self
.session_builder
.session_state_builder(state_builder)
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let state = state_builder.build();
let mut state = self
.session_builder
.session_state(state)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let function_registry = state.function_registry().ok_or(Status::invalid_argument(
"FunctionRegistry not present in newly built SessionState",
))?;

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(&user_codec));
}

let stage = stage_from_proto(
stage_msg,
function_registry,
&self.runtime.as_ref(),
&combined_codec,
)
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
let inner_plan = Arc::clone(&stage.plan);

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(Arc::new(stage));

let ctx = SessionContext::new_with_state(state);

let ctx = self
.session_builder
.session_context(ctx)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
let partition = doget.partition as usize;
let task_number = doget.task_number as usize;
let (mut state, stage) = self.get_state_and_stage(doget).await?;

// find out which partition group we are executing
let task = stage
.tasks
.get(task_number)
.ok_or(Status::invalid_argument(format!(
"Task number {} not found in stage {}",
task_number,
stage.name()
)))?;

let partition_group =
PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect());
state.config_mut().set_extension(Arc::new(partition_group));

let inner_plan = stage.plan.clone();

/*println!(
"{} Task {:?} executing partition {}",
stage.name(),
task.partition_group,
partition
);*/
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this maybe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


let stream = inner_plan
.execute(doget.partition as usize, ctx.task_ctx())
.execute(partition, state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

let flight_data_stream = FlightDataEncoderBuilder::new()
Expand All @@ -104,4 +92,68 @@ impl ArrowFlightEndpoint {
},
))))
}

async fn get_state_and_stage(
&self,
doget: DoGet,
) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
let key = doget
.stage_key
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
let once_stage = self.stages.entry(key).or_default();

let (state, stage) = once_stage
.get_or_try_init(|| async {
Comment on lines +99 to +100
Copy link
Collaborator

@gabotechs gabotechs Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will lock the once_stage RefMut across the initialization, locking a shard in the self.stages dashmap across an asynchronous gap, which might be too much locking.

Fortunately, it's very easy to prevent this:

  • we can make the OnceCell a shared reference:
    pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
  • and then immediately drop the reference to the dashmap entry
        let once_stage = {
            let entry = self.stages.entry(key).or_default();
            Arc::clone(&entry)
            // <- dashmap RefMut get's dropped, releasing the lock for the current shard
        };

Copy link
Collaborator Author

@robtandy robtandy Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good improvement. added.

let stage_proto = doget
.stage_proto
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

let state_builder = SessionStateBuilder::new()
.with_runtime_env(Arc::clone(&self.runtime))
.with_default_features();
let state_builder = self
.session_builder
.session_state_builder(state_builder)
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let state = state_builder.build();
let mut state = self
.session_builder
.session_state(state)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let function_registry =
state.function_registry().ok_or(Status::invalid_argument(
"FunctionRegistry not present in newly built SessionState",
))?;

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(user_codec));
}

let stage = stage_from_proto(
stage_proto,
function_registry,
self.runtime.as_ref(),
&combined_codec,
)
.map(Arc::new)
.map_err(|err| {
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
})?;

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(stage.clone());

Ok::<_, Status>((state, stage))
})
.await?;

Ok((state.clone(), stage.clone()))
}
}
3 changes: 1 addition & 2 deletions src/flight_service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
mod do_get;
mod service;
mod session_builder;
mod stream_partitioner_registry;

pub(crate) use do_get::DoGet;

pub use service::ArrowFlightEndpoint;
pub use service::{ArrowFlightEndpoint, StageKey};
pub use session_builder::{NoopSessionBuilder, SessionBuilder};
Loading