1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
@@ -1,7 +1,5 @@
[workspace]
members = [
"benchmarks"
]
members = ["benchmarks"]

[workspace.dependencies]
datafusion = { version = "49.0.0" }
@@ -37,14 +35,16 @@ tpchgen = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d82343252
tpchgen-arrow = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d823432528eed4f70fca5a1296a66c68a389a8", optional = true }
parquet = { version = "55.2.0", optional = true }
arrow = { version = "55.2.0", optional = true }
tokio-stream = { version = "0.1.17", optional = true }

[features]
integration = [
"insta",
"tpchgen",
"tpchgen-arrow",
"tpchgen-arrow",
"parquet",
"arrow"
"arrow",
"tokio-stream",
]

[dev-dependencies]
9 changes: 8 additions & 1 deletion benchmarks/README.md
@@ -14,4 +14,11 @@ After generating the data with the command above:

```shell
cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1
```
```

To validate the correctness of the results against single-node execution, add `--validate`:

```shell
cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 --validate
```
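
As a rough illustration of what this validation implies (the helper and names below are illustrative, not the benchmark's actual API): the same SQL is executed both through the distributed context and through a plain single-node `SessionContext`, and the collected results are compared.

```rust
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Hypothetical sketch: run a query on the distributed context and on a fresh
// single-node context, then compare the pretty-printed batches. Registering
// the TPCH tables on the single-node context is elided here.
async fn results_match(distributed_ctx: &SessionContext, sql: &str) -> Result<bool> {
    let distributed = distributed_ctx.sql(sql).await?.collect().await?;

    let single_node_ctx = SessionContext::new();
    // ... register the same TPCH tables on `single_node_ctx` ...
    let single_node = single_node_ctx.sql(sql).await?.collect().await?;

    // Comparing formatted output sidesteps differences in batch boundaries.
    Ok(pretty_format_batches(&distributed)?.to_string()
        == pretty_format_batches(&single_node)?.to_string())
}
```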
44 changes: 29 additions & 15 deletions benchmarks/src/tpch/run.rs
@@ -100,8 +100,7 @@ pub struct RunOpt {
#[structopt(short = "t", long = "sorted")]
sorted: bool,

/// Mark the first column of each table as sorted in ascending order.
/// The tables should have been created with the `--sort` option for this to have any effect.
/// The maximum number of partitions per task.
#[structopt(long = "ppt")]
partitions_per_task: Option<usize>,
}
@@ -115,8 +114,23 @@ impl SessionBuilder for RunOpt {
let mut config = self
.common
.config()?
.with_collect_statistics(!self.disable_statistics);
.with_collect_statistics(!self.disable_statistics)
.with_target_partitions(self.partitions());

// FIXME: these three options are critical for the correct function of the library
// but we are not enforcing that the user sets them. They are here at the moment
// but we should figure out a way to do this better.
config
.options_mut()
.optimizer
.hash_join_single_partition_threshold = 0;
config
.options_mut()
.optimizer
.hash_join_single_partition_threshold_rows = 0;

config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join;
// end critical options section
let rt_builder = self.common.runtime_env_builder()?;

let mut rule = DistributedPhysicalOptimizerRule::new();
@@ -140,7 +154,7 @@ impl SessionBuilder for RunOpt {

impl RunOpt {
pub async fn run(self) -> Result<()> {
let (ctx, _guard) = start_localhost_context([50051], self.clone()).await;
let (ctx, _guard) = start_localhost_context(1, self.clone()).await;
println!("Running benchmarks with the following options: {self:?}");
let query_range = match self.query {
Some(query_id) => query_id..=query_id,
@@ -180,23 +194,22 @@ impl RunOpt {

let sql = &get_query_sql(query_id)?;

let single_node_ctx = SessionContext::new();
self.register_tables(&single_node_ctx).await?;

for i in 0..self.iterations() {
let start = Instant::now();
let mut result = vec![];

// query 15 is special, with 3 statements. the second statement is the one from which we
// want to capture the results
let mut result = vec![];
if query_id == 15 {
for (n, query) in sql.iter().enumerate() {
if n == 1 {
result = self.execute_query(ctx, query).await?;
} else {
self.execute_query(ctx, query).await?;
}
}
} else {
for query in sql {
let result_stmt = if query_id == 15 { 1 } else { sql.len() - 1 };

for (i, query) in sql.iter().enumerate() {
if i == result_stmt {
result = self.execute_query(ctx, query).await?;
} else {
self.execute_query(ctx, query).await?;
}
}

@@ -208,6 +221,7 @@ impl RunOpt {
println!(
"Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
);

query_results.push(QueryResult { elapsed, row_count });
}

20 changes: 0 additions & 20 deletions context.rs

This file was deleted.

4 changes: 2 additions & 2 deletions src/errors/arrow_error.rs
@@ -224,8 +224,8 @@ mod tests {
let (recovered_error, recovered_ctx) = proto.to_arrow_error();

if original_error.to_string() != recovered_error.to_string() {
println!("original error: {}", original_error.to_string());
println!("recovered error: {}", recovered_error.to_string());
println!("original error: {}", original_error);
println!("recovered error: {}", recovered_error);
}

assert_eq!(original_error.to_string(), recovered_error.to_string());
4 changes: 2 additions & 2 deletions src/errors/mod.rs
@@ -17,8 +17,8 @@ mod schema_error;
pub fn datafusion_error_to_tonic_status(err: &DataFusionError) -> tonic::Status {
let err = DataFusionErrorProto::from_datafusion_error(err);
let err = err.encode_to_vec();
let status = tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into());
status

tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into())
}

/// Decodes a [DataFusionError] from a [tonic::Status] error. If the provided [tonic::Status]
2 changes: 1 addition & 1 deletion src/errors/schema_error.rs
@@ -198,7 +198,7 @@ impl SchemaErrorProto {
valid_fields,
} => SchemaErrorProto {
inner: Some(SchemaErrorInnerProto::FieldNotFound(FieldNotFoundProto {
field: Some(Box::new(ColumnProto::from_column(&field))),
field: Some(Box::new(ColumnProto::from_column(field))),
valid_fields: valid_fields.iter().map(ColumnProto::from_column).collect(),
})),
backtrace: backtrace.cloned(),
158 changes: 102 additions & 56 deletions src/flight_service/do_get.rs
@@ -1,29 +1,39 @@
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::plan::DistributedCodec;
use crate::stage::{stage_from_proto, ExecutionStageProto};
use crate::plan::{DistributedCodec, PartitionGroup};
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
use crate::user_provided_codec::get_user_codec;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::optimizer::OptimizerConfig;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use tonic::{Request, Response, Status};

use super::service::StageKey;

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DoGet {
/// The ExecutionStage that we are going to execute
#[prost(message, optional, tag = "1")]
pub stage_proto: Option<ExecutionStageProto>,
/// the partition of the stage to execute
/// The index to the task within the stage that we want to execute
#[prost(uint64, tag = "2")]
pub task_number: u64,
/// the partition number we want to execute
#[prost(uint64, tag = "3")]
pub partition: u64,
/// The stage key that identifies the stage. This is useful to keep
/// outside of the stage proto as it is used to store the stage
/// and we may not need to deserialize the entire stage proto
/// if we already have stored it
#[prost(message, optional, tag = "4")]
pub stage_key: Option<StageKey>,
}

impl ArrowFlightEndpoint {
@@ -36,59 +46,28 @@ impl ArrowFlightEndpoint {
Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
})?;

let stage_msg = doget
.stage_proto
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

let state_builder = SessionStateBuilder::new()
.with_runtime_env(Arc::clone(&self.runtime))
.with_default_features();
let state_builder = self
.session_builder
.session_state_builder(state_builder)
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let state = state_builder.build();
let mut state = self
.session_builder
.session_state(state)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let function_registry = state.function_registry().ok_or(Status::invalid_argument(
"FunctionRegistry not present in newly built SessionState",
))?;

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(&user_codec));
}

let stage = stage_from_proto(
stage_msg,
function_registry,
&self.runtime.as_ref(),
&combined_codec,
)
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
let inner_plan = Arc::clone(&stage.plan);

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(Arc::new(stage));

let ctx = SessionContext::new_with_state(state);

let ctx = self
.session_builder
.session_context(ctx)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
let partition = doget.partition as usize;
let task_number = doget.task_number as usize;
let (mut state, stage) = self.get_state_and_stage(doget).await?;

// find out which partition group we are executing
let task = stage
.tasks
.get(task_number)
.ok_or(Status::invalid_argument(format!(
"Task number {} not found in stage {}",
task_number,
stage.name()
)))?;

let partition_group =
PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect());
state.config_mut().set_extension(Arc::new(partition_group));

let inner_plan = stage.plan.clone();

let stream = inner_plan
.execute(doget.partition as usize, ctx.task_ctx())
.execute(partition, state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

let flight_data_stream = FlightDataEncoderBuilder::new()
@@ -104,4 +83,71 @@ impl ArrowFlightEndpoint {
},
))))
}

async fn get_state_and_stage(
&self,
doget: DoGet,
) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
let key = doget
.stage_key
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
let once_stage = {
let entry = self.stages.entry(key).or_default();
Arc::clone(&entry)
};

let (state, stage) = once_stage
.get_or_try_init(|| async {
Comment on lines +99 to +100

@gabotechs (Collaborator), Aug 11, 2025:

This will lock the once_stage RefMut across the initialization, locking a shard in the self.stages dashmap across an asynchronous gap, which might be too much locking.

Fortunately, it's very easy to prevent this:

  • we can make the OnceCell a shared reference:

        pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,

  • and then immediately drop the reference to the dashmap entry:

        let once_stage = {
            let entry = self.stages.entry(key).or_default();
            Arc::clone(&entry)
            // <- dashmap RefMut gets dropped, releasing the lock for the current shard
        };

@robtandy (Collaborator, Author), Aug 12, 2025:

A good improvement. Added.

let stage_proto = doget
.stage_proto
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

let state_builder = SessionStateBuilder::new()
.with_runtime_env(Arc::clone(&self.runtime))
.with_default_features();
let state_builder = self
.session_builder
.session_state_builder(state_builder)
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let state = state_builder.build();
let mut state = self
.session_builder
.session_state(state)
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let function_registry =
state.function_registry().ok_or(Status::invalid_argument(
"FunctionRegistry not present in newly built SessionState",
))?;

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(user_codec));
}

let stage = stage_from_proto(
stage_proto,
function_registry,
self.runtime.as_ref(),
&combined_codec,
)
.map(Arc::new)
.map_err(|err| {
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
})?;

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(stage.clone());

Ok::<_, Status>((state, stage))
})
.await?;

Ok((state.clone(), stage.clone()))
}
}
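
The `Arc<OnceCell>`-inside-`DashMap` pattern discussed in the review thread above can be sketched in isolation roughly as follows (the types are simplified stand-ins, not the crate's actual `StageKey`/`ExecutionStage`; this only illustrates the locking behavior):

```rust
use std::sync::Arc;

use dashmap::DashMap;
use tokio::sync::OnceCell;

// Simplified stand-ins for the real StageKey and decoded stage.
type Key = String;
type Stage = String;

struct Endpoint {
    // Each entry holds an Arc'd OnceCell so the DashMap shard lock can be
    // released before the (async, possibly slow) initialization runs.
    stages: DashMap<Key, Arc<OnceCell<Stage>>>,
}

impl Endpoint {
    async fn get_or_init_stage(&self, key: Key) -> Stage {
        // Clone the Arc out of the entry and immediately drop the RefMut,
        // releasing the shard lock before any await point.
        let cell = {
            let entry = self.stages.entry(key).or_default();
            Arc::clone(&entry)
        };
        // Initialization happens outside the DashMap lock; concurrent callers
        // with the same key still initialize at most once.
        cell.get_or_init(|| async { expensive_decode().await })
            .await
            .clone()
    }
}

async fn expensive_decode() -> Stage {
    // Stand-in for decoding the stage proto and building the SessionState.
    "decoded stage".to_string()
}
```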
2 changes: 1 addition & 1 deletion src/flight_service/mod.rs
@@ -5,5 +5,5 @@ mod stream_partitioner_registry;

pub(crate) use do_get::DoGet;

pub use service::ArrowFlightEndpoint;
pub use service::{ArrowFlightEndpoint, StageKey};
pub use session_builder::{NoopSessionBuilder, SessionBuilder};