Commit a2a8163

robtandy and gabotechs authored

Execution working on all 22 TPCH queries (#89)
* add comment for execution stage struct
* Allow passing custom codecs
* Better UX for providing user defined codecs
* Add docs
* fix execution by storing stages in a OnceCell
* move validation to integration tests
* address lints and add back stream_partitioner_registry
* address lints and lock more granularly in do_get
* add renamed file

---------

Co-authored-by: Gabriel Musat Mestre <[email protected]>

1 parent 92c5a0e commit a2a8163

28 files changed (+408, -345 lines)

Cargo.lock

Lines changed: 1 addition & 0 deletions
(Generated file; the diff is not rendered.)

Cargo.toml

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,5 @@
 [workspace]
-members = [
-    "benchmarks"
-]
+members = ["benchmarks"]
 
 [workspace.dependencies]
 datafusion = { version = "49.0.0" }
@@ -37,14 +35,16 @@ tpchgen = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d82343252
 tpchgen-arrow = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d823432528eed4f70fca5a1296a66c68a389a8", optional = true }
 parquet = { version = "55.2.0", optional = true }
 arrow = { version = "55.2.0", optional = true }
+tokio-stream = { version = "0.1.17", optional = true }
 
 [features]
 integration = [
     "insta",
     "tpchgen",
-    "tpchgen-arrow",
+    "tpchgen-arrow",
     "parquet",
-    "arrow"
+    "arrow",
+    "tokio-stream",
 ]
 
 [dev-dependencies]

benchmarks/README.md

Lines changed: 8 additions & 1 deletion
@@ -14,4 +14,11 @@ After generating the data with the command above:
 
 ```shell
 cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1
-```
+```
+
+In order to validate the correctness of the results against single node execution, add
+`--validate`
+
+```shell
+cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 --validate
+```
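
A note on what `--validate` implies: the distributed results are compared against the same queries run on a plain single-node `SessionContext` (the run.rs changes below register the tables in such a context for this purpose). A minimal sketch of that comparison, using only standard DataFusion APIs; the function and its exact comparison logic are illustrative, not the benchmark's actual code:

```rust
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch: run the SQL on a single-node context and compare the formatted
// output with batches produced by the distributed context. The real
// benchmark's validation may normalize ordering or compare differently.
async fn results_match(
    single_node_ctx: &SessionContext,
    sql: &str,
    distributed_batches: &[RecordBatch],
) -> Result<bool> {
    let expected = single_node_ctx.sql(sql).await?.collect().await?;
    let expected_str = pretty_format_batches(&expected)?.to_string();
    let actual_str = pretty_format_batches(distributed_batches)?.to_string();
    Ok(expected_str == actual_str)
}
```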

benchmarks/src/tpch/run.rs

Lines changed: 29 additions & 15 deletions
@@ -100,8 +100,7 @@ pub struct RunOpt {
     #[structopt(short = "t", long = "sorted")]
     sorted: bool,
 
-    /// Mark the first column of each table as sorted in ascending order.
-    /// The tables should have been created with the `--sort` option for this to have any effect.
+    /// The maximum number of partitions per task.
     #[structopt(long = "ppt")]
     partitions_per_task: Option<usize>,
 }
@@ -115,8 +114,23 @@ impl SessionBuilder for RunOpt {
         let mut config = self
             .common
             .config()?
-            .with_collect_statistics(!self.disable_statistics);
+            .with_collect_statistics(!self.disable_statistics)
+            .with_target_partitions(self.partitions());
+
+        // FIXME: these three options are critical for the correct function of the library
+        // but we are not enforcing that the user sets them. They are here at the moment
+        // but we should figure out a way to do this better.
+        config
+            .options_mut()
+            .optimizer
+            .hash_join_single_partition_threshold = 0;
+        config
+            .options_mut()
+            .optimizer
+            .hash_join_single_partition_threshold_rows = 0;
+
         config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join;
+        // end critical options section
         let rt_builder = self.common.runtime_env_builder()?;
 
         let mut rule = DistributedPhysicalOptimizerRule::new();
@@ -140,7 +154,7 @@
 
 impl RunOpt {
     pub async fn run(self) -> Result<()> {
-        let (ctx, _guard) = start_localhost_context([50051], self.clone()).await;
+        let (ctx, _guard) = start_localhost_context(1, self.clone()).await;
         println!("Running benchmarks with the following options: {self:?}");
         let query_range = match self.query {
             Some(query_id) => query_id..=query_id,
@@ -180,23 +194,22 @@
 
         let sql = &get_query_sql(query_id)?;
 
+        let single_node_ctx = SessionContext::new();
+        self.register_tables(&single_node_ctx).await?;
+
        for i in 0..self.iterations() {
             let start = Instant::now();
+            let mut result = vec![];
 
             // query 15 is special, with 3 statements. the second statement is the one from which we
             // want to capture the results
-            let mut result = vec![];
-            if query_id == 15 {
-                for (n, query) in sql.iter().enumerate() {
-                    if n == 1 {
-                        result = self.execute_query(ctx, query).await?;
-                    } else {
-                        self.execute_query(ctx, query).await?;
-                    }
-                }
-            } else {
-                for query in sql {
+            let result_stmt = if query_id == 15 { 1 } else { sql.len() - 1 };
+
+            for (i, query) in sql.iter().enumerate() {
+                if i == result_stmt {
                     result = self.execute_query(ctx, query).await?;
+                } else {
+                    self.execute_query(ctx, query).await?;
                 }
             }
 
@@ -208,6 +221,7 @@
             println!(
                 "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
             );
+
             query_results.push(QueryResult { elapsed, row_count });
         }
 

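The FIXME in the second hunk flags three optimizer options the library currently depends on but cannot enforce. A minimal sketch of setting them from user code, using only standard DataFusion `SessionConfig` APIs (the wrapper function itself is hypothetical):

```rust
use datafusion::prelude::SessionConfig;

// Sketch: configure a session the way this benchmark does. The three
// optimizer options below are the ones the FIXME calls critical for
// distributed execution; everything else is ordinary configuration.
fn distributed_friendly_config(target_partitions: usize) -> SessionConfig {
    let mut config = SessionConfig::new().with_target_partitions(target_partitions);
    // Keep hash joins partitioned; collapsing to a single partition would
    // defeat distributing the join across tasks.
    config.options_mut().optimizer.hash_join_single_partition_threshold = 0;
    config.options_mut().optimizer.hash_join_single_partition_threshold_rows = 0;
    config.options_mut().optimizer.prefer_hash_join = true;
    config
}
```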
context.rs

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/errors/arrow_error.rs

Lines changed: 2 additions & 2 deletions
@@ -224,8 +224,8 @@ mod tests {
         let (recovered_error, recovered_ctx) = proto.to_arrow_error();
 
         if original_error.to_string() != recovered_error.to_string() {
-            println!("original error: {}", original_error.to_string());
-            println!("recovered error: {}", recovered_error.to_string());
+            println!("original error: {}", original_error);
+            println!("recovered error: {}", recovered_error);
         }
 
         assert_eq!(original_error.to_string(), recovered_error.to_string());

src/errors/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ mod schema_error;
 pub fn datafusion_error_to_tonic_status(err: &DataFusionError) -> tonic::Status {
     let err = DataFusionErrorProto::from_datafusion_error(err);
     let err = err.encode_to_vec();
-    let status = tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into());
-    status
+
+    tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into())
 }
 
 /// Decodes a [DataFusionError] from a [tonic::Status] error. If the provided [tonic::Status]
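
For context, `datafusion_error_to_tonic_status` ships the encoded `DataFusionErrorProto` in the gRPC status details so the structured error can be rebuilt on the other side (the decoding helper lives in this module but outside this hunk). A hedged usage sketch:

```rust
use datafusion::error::DataFusionError;

// Sketch, assuming this sits next to datafusion_error_to_tonic_status in
// src/errors/mod.rs: convert a DataFusionError into a gRPC status whose
// details carry the encoded proto, so clients can recover the structured
// error instead of parsing a flat message string.
fn reply_with_error(err: DataFusionError) -> tonic::Status {
    let status = datafusion_error_to_tonic_status(&err);
    debug_assert_eq!(status.code(), tonic::Code::Internal);
    status
}
```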

src/errors/schema_error.rs

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ impl SchemaErrorProto {
                 valid_fields,
             } => SchemaErrorProto {
                 inner: Some(SchemaErrorInnerProto::FieldNotFound(FieldNotFoundProto {
-                    field: Some(Box::new(ColumnProto::from_column(&field))),
+                    field: Some(Box::new(ColumnProto::from_column(field))),
                     valid_fields: valid_fields.iter().map(ColumnProto::from_column).collect(),
                 })),
                 backtrace: backtrace.cloned(),

src/flight_service/do_get.rs

Lines changed: 102 additions & 56 deletions
@@ -1,29 +1,39 @@
 use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::flight_service::service::ArrowFlightEndpoint;
-use crate::plan::DistributedCodec;
-use crate::stage::{stage_from_proto, ExecutionStageProto};
+use crate::plan::{DistributedCodec, PartitionGroup};
+use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
 use crate::user_provided_codec::get_user_codec;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
-use datafusion::execution::SessionStateBuilder;
+use datafusion::execution::{SessionState, SessionStateBuilder};
 use datafusion::optimizer::OptimizerConfig;
-use datafusion::prelude::SessionContext;
 use futures::TryStreamExt;
 use prost::Message;
 use std::sync::Arc;
 use tonic::{Request, Response, Status};
 
+use super::service::StageKey;
+
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct DoGet {
     /// The ExecutionStage that we are going to execute
     #[prost(message, optional, tag = "1")]
     pub stage_proto: Option<ExecutionStageProto>,
-    /// the partition of the stage to execute
+    /// The index to the task within the stage that we want to execute
     #[prost(uint64, tag = "2")]
+    pub task_number: u64,
+    /// the partition number we want to execute
+    #[prost(uint64, tag = "3")]
     pub partition: u64,
+    /// The stage key that identifies the stage. This is useful to keep
+    /// outside of the stage proto as it is used to store the stage
+    /// and we may not need to deserialize the entire stage proto
+    /// if we already have stored it
+    #[prost(message, optional, tag = "4")]
+    pub stage_key: Option<StageKey>,
 }
 
 impl ArrowFlightEndpoint {
@@ -36,59 +46,28 @@ impl ArrowFlightEndpoint {
             Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
         })?;
 
-        let stage_msg = doget
-            .stage_proto
-            .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
-
-        let state_builder = SessionStateBuilder::new()
-            .with_runtime_env(Arc::clone(&self.runtime))
-            .with_default_features();
-        let state_builder = self
-            .session_builder
-            .session_state_builder(state_builder)
-            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
-
-        let state = state_builder.build();
-        let mut state = self
-            .session_builder
-            .session_state(state)
-            .await
-            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
-
-        let function_registry = state.function_registry().ok_or(Status::invalid_argument(
-            "FunctionRegistry not present in newly built SessionState",
-        ))?;
-
-        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
-        combined_codec.push(DistributedCodec);
-        if let Some(ref user_codec) = get_user_codec(state.config()) {
-            combined_codec.push_arc(Arc::clone(&user_codec));
-        }
-
-        let stage = stage_from_proto(
-            stage_msg,
-            function_registry,
-            &self.runtime.as_ref(),
-            &combined_codec,
-        )
-        .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
-        let inner_plan = Arc::clone(&stage.plan);
-
-        // Add the extensions that might be required for ExecutionPlan nodes in the plan
-        let config = state.config_mut();
-        config.set_extension(Arc::clone(&self.channel_manager));
-        config.set_extension(Arc::new(stage));
-
-        let ctx = SessionContext::new_with_state(state);
-
-        let ctx = self
-            .session_builder
-            .session_context(ctx)
-            .await
-            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
+        let partition = doget.partition as usize;
+        let task_number = doget.task_number as usize;
+        let (mut state, stage) = self.get_state_and_stage(doget).await?;
+
+        // find out which partition group we are executing
+        let task = stage
+            .tasks
+            .get(task_number)
+            .ok_or(Status::invalid_argument(format!(
+                "Task number {} not found in stage {}",
+                task_number,
+                stage.name()
+            )))?;
+
+        let partition_group =
+            PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect());
+        state.config_mut().set_extension(Arc::new(partition_group));
+
+        let inner_plan = stage.plan.clone();
 
         let stream = inner_plan
-            .execute(doget.partition as usize, ctx.task_ctx())
+            .execute(partition, state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let flight_data_stream = FlightDataEncoderBuilder::new()
@@ -104,4 +83,71 @@ impl ArrowFlightEndpoint {
             },
         ))))
     }
+
+    async fn get_state_and_stage(
+        &self,
+        doget: DoGet,
+    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+        let key = doget
+            .stage_key
+            .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
+        let once_stage = {
+            let entry = self.stages.entry(key).or_default();
+            Arc::clone(&entry)
+        };
+
+        let (state, stage) = once_stage
+            .get_or_try_init(|| async {
+                let stage_proto = doget
+                    .stage_proto
+                    .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
+
+                let state_builder = SessionStateBuilder::new()
+                    .with_runtime_env(Arc::clone(&self.runtime))
+                    .with_default_features();
+                let state_builder = self
+                    .session_builder
+                    .session_state_builder(state_builder)
+                    .map_err(|err| datafusion_error_to_tonic_status(&err))?;
+
+                let state = state_builder.build();
+                let mut state = self
+                    .session_builder
+                    .session_state(state)
+                    .await
+                    .map_err(|err| datafusion_error_to_tonic_status(&err))?;
+
+                let function_registry =
+                    state.function_registry().ok_or(Status::invalid_argument(
+                        "FunctionRegistry not present in newly built SessionState",
+                    ))?;
+
+                let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+                combined_codec.push(DistributedCodec);
+                if let Some(ref user_codec) = get_user_codec(state.config()) {
+                    combined_codec.push_arc(Arc::clone(user_codec));
+                }
+
+                let stage = stage_from_proto(
+                    stage_proto,
+                    function_registry,
+                    self.runtime.as_ref(),
+                    &combined_codec,
+                )
+                .map(Arc::new)
+                .map_err(|err| {
+                    Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
+                })?;
+
+                // Add the extensions that might be required for ExecutionPlan nodes in the plan
+                let config = state.config_mut();
+                config.set_extension(Arc::clone(&self.channel_manager));
+                config.set_extension(stage.clone());
+
+                Ok::<_, Status>((state, stage))
+            })
+            .await?;
+
+        Ok((state.clone(), stage.clone()))
+    }
 }
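
The locking scheme in `get_state_and_stage` is the heart of the "storing stages in a OnceCell" change: each `StageKey` maps to a once-initialized cell, so the stage proto is deserialized and the session built at most once, then reused across the partitions and tasks that hit the same endpoint. A minimal standalone sketch of the pattern, assuming `self.stages` is a `dashmap::DashMap` holding `Arc<tokio::sync::OnceCell<...>>` values (the concrete field types and `StageKey`/`ExecutionStage` definitions are not visible in this diff):

```rust
use dashmap::DashMap;
use std::sync::Arc;
use tokio::sync::OnceCell;

// Hypothetical stand-ins for the real StageKey / ExecutionStage types.
type StageKey = (String, u64);
struct ExecutionStage;

struct Endpoint {
    stages: DashMap<StageKey, Arc<OnceCell<Arc<ExecutionStage>>>>,
}

impl Endpoint {
    async fn get_stage(&self, key: StageKey) -> Result<Arc<ExecutionStage>, String> {
        // Clone the Arc out of the map entry so the map's shard lock is
        // released before awaiting.
        let once = Arc::clone(&self.stages.entry(key).or_default());
        let stage = once
            .get_or_try_init(|| async {
                // Expensive work (proto decoding, session building) runs at
                // most once per key, even with concurrent callers.
                Ok::<_, String>(Arc::new(ExecutionStage))
            })
            .await?;
        Ok(Arc::clone(stage))
    }
}
```

Scoping the map access to a small block before the `await` is what the commit's "lock more granularly in do_get" bullet refers to: only the cheap entry lookup holds the map lock, while the slow initialization happens on the cell.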

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@ mod stream_partitioner_registry;
 
 pub(crate) use do_get::DoGet;
 
-pub use service::ArrowFlightEndpoint;
+pub use service::{ArrowFlightEndpoint, StageKey};
 pub use session_builder::{NoopSessionBuilder, SessionBuilder};
