Execution working on all 22 TPCH queries #89
Diff of the `DoGet` Flight handler (old and new lines shown interleaved, as in the original diff view):

`@@ -1,29 +1,41 @@`

```rust
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::plan::DistributedCodec;
use crate::stage::{stage_from_proto, ExecutionStageProto};
use crate::plan::{DistributedCodec, PartitionGroup};
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
use crate::user_provided_codec::get_user_codec;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::optimizer::OptimizerConfig;
use datafusion::prelude::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tonic::{Request, Response, Status};

use super::service::StageKey;
```
```rust
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DoGet {
    /// The ExecutionStage that we are going to execute
    #[prost(message, optional, tag = "1")]
    pub stage_proto: Option<ExecutionStageProto>,
    /// the partition of the stage to execute
    /// The index to the task within the stage that we want to execute
    #[prost(uint64, tag = "2")]
    pub task_number: u64,
    /// the partition number we want to execute
    #[prost(uint64, tag = "3")]
    pub partition: u64,
    /// The stage key that identifies the stage. This is useful to keep
    /// outside of the stage proto as it is used to store the stage
    /// and we may not need to deserialize the entire stage proto
    /// if we already have stored it
    #[prost(message, optional, tag = "4")]
    pub stage_key: Option<StageKey>,
}
```
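For illustration only, a client could build the Flight `Ticket` that carries this message roughly as follows (a sketch based on the struct above; the `stage_proto` and `stage_key` values are placeholders rather than how the crate actually populates them):

```rust
use arrow_flight::Ticket;
use prost::Message;

// Hypothetical client-side sketch: ask task 0 of a stage to execute partition 3.
// A real caller would fill in `stage_proto` and `stage_key` from the planner;
// they are left as `None` here only to keep the example short.
fn example_ticket() -> Ticket {
    let doget = DoGet {
        stage_proto: None,
        task_number: 0,
        partition: 3,
        stage_key: None,
    };
    Ticket {
        ticket: doget.encode_to_vec().into(),
    }
}
```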
```rust
impl ArrowFlightEndpoint {
```

`@@ -36,59 +48,35 @@ impl ArrowFlightEndpoint`

```rust
            Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
        })?;

        let stage_msg = doget
            .stage_proto
            .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

        let state_builder = SessionStateBuilder::new()
            .with_runtime_env(Arc::clone(&self.runtime))
            .with_default_features();
        let state_builder = self
            .session_builder
            .session_state_builder(state_builder)
            .map_err(|err| datafusion_error_to_tonic_status(&err))?;

        let state = state_builder.build();
        let mut state = self
            .session_builder
            .session_state(state)
            .await
            .map_err(|err| datafusion_error_to_tonic_status(&err))?;

        let function_registry = state.function_registry().ok_or(Status::invalid_argument(
            "FunctionRegistry not present in newly built SessionState",
        ))?;

        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
        combined_codec.push(DistributedCodec);
        if let Some(ref user_codec) = get_user_codec(state.config()) {
            combined_codec.push_arc(Arc::clone(&user_codec));
        }

        let stage = stage_from_proto(
            stage_msg,
            function_registry,
            &self.runtime.as_ref(),
            &combined_codec,
        )
        .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
        let inner_plan = Arc::clone(&stage.plan);

        // Add the extensions that might be required for ExecutionPlan nodes in the plan
        let config = state.config_mut();
        config.set_extension(Arc::clone(&self.channel_manager));
        config.set_extension(Arc::new(stage));

        let ctx = SessionContext::new_with_state(state);

        let ctx = self
            .session_builder
            .session_context(ctx)
            .await
            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
        let partition = doget.partition as usize;
        let task_number = doget.task_number as usize;
        let (mut state, stage) = self.get_state_and_stage(doget).await?;

        // find out which partition group we are executing
        let task = stage
            .tasks
            .get(task_number)
            .ok_or(Status::invalid_argument(format!(
                "Task number {} not found in stage {}",
                task_number,
                stage.name()
            )))?;

        let partition_group =
            PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect());
        state.config_mut().set_extension(Arc::new(partition_group));

        let inner_plan = stage.plan.clone();

        /*println!(
            "{} Task {:?} executing partition {}",
            stage.name(),
            task.partition_group,
            partition
        );*/

        let stream = inner_plan
            .execute(doget.partition as usize, ctx.task_ctx())
            .execute(partition, state.task_ctx())
            .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

        let flight_data_stream = FlightDataEncoderBuilder::new()
```
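The handler above hands the task's partition group to the plan through a `SessionConfig` extension. For illustration, a plan node could read it back from its `TaskContext` roughly like this (a hedged sketch; the `PartitionGroup` struct here is a stand-in for the crate's own type, and this is not necessarily how the crate's operators actually consume it):

```rust
use std::sync::Arc;
use datafusion::execution::TaskContext;

// Hypothetical placeholder mirroring the crate's `PartitionGroup` extension type.
struct PartitionGroup(Vec<usize>);

fn partition_group_from_ctx(ctx: &Arc<TaskContext>) -> Option<Arc<PartitionGroup>> {
    // `set_extension` stores the value in the session config's extension map;
    // `get_extension` retrieves it again by type.
    ctx.session_config().get_extension::<PartitionGroup>()
}
```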
@@ -104,4 +92,68 @@ impl ArrowFlightEndpoint { | |
| }, | ||
| )))) | ||
| } | ||
|
|
||
| async fn get_state_and_stage( | ||
| &self, | ||
| doget: DoGet, | ||
| ) -> Result<(SessionState, Arc<ExecutionStage>), Status> { | ||
| let key = doget | ||
| .stage_key | ||
| .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?; | ||
| let once_stage = self.stages.entry(key).or_default(); | ||
|
|
||
| let (state, stage) = once_stage | ||
| .get_or_try_init(|| async { | ||
|
Comment on lines
+99
to
+100
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will lock the Fortunately, it's very easy to prevent this:
pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
let once_stage = {
let entry = self.stages.entry(key).or_default();
Arc::clone(&entry)
// <- dashmap RefMut get's dropped, releasing the lock for the current shard
};
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A good improvement. added. |
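For reference, a minimal self-contained sketch of the suggested pattern (assuming the `dashmap` crate and tokio's `OnceCell`; the key and stage types are simplified placeholders for `StageKey` and the `(SessionState, Arc<ExecutionStage>)` pair):

```rust
use dashmap::DashMap;
use std::sync::Arc;
use tokio::sync::OnceCell;

// Placeholder types standing in for `StageKey` and the cached stage value.
type Key = String;
type Stage = String;

async fn get_or_init_stage(
    stages: &DashMap<Key, Arc<OnceCell<Arc<Stage>>>>,
    key: Key,
) -> Arc<Stage> {
    // Clone the Arc inside a block so the DashMap `RefMut` guard is dropped,
    // releasing the shard lock before we await on the (possibly slow) init.
    let once = {
        let entry = stages.entry(key).or_default();
        Arc::clone(&entry)
    };
    once.get_or_init(|| async {
            // Expensive async initialization runs without holding the shard lock.
            Arc::new(Stage::from("decoded stage"))
        })
        .await
        .clone()
}
```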
```rust
                let stage_proto = doget
                    .stage_proto
                    .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

                let state_builder = SessionStateBuilder::new()
                    .with_runtime_env(Arc::clone(&self.runtime))
                    .with_default_features();
                let state_builder = self
                    .session_builder
                    .session_state_builder(state_builder)
                    .map_err(|err| datafusion_error_to_tonic_status(&err))?;

                let state = state_builder.build();
                let mut state = self
                    .session_builder
                    .session_state(state)
                    .await
                    .map_err(|err| datafusion_error_to_tonic_status(&err))?;

                let function_registry =
                    state.function_registry().ok_or(Status::invalid_argument(
                        "FunctionRegistry not present in newly built SessionState",
                    ))?;

                let mut combined_codec = ComposedPhysicalExtensionCodec::default();
                combined_codec.push(DistributedCodec);
                if let Some(ref user_codec) = get_user_codec(state.config()) {
                    combined_codec.push_arc(Arc::clone(user_codec));
                }

                let stage = stage_from_proto(
                    stage_proto,
                    function_registry,
                    self.runtime.as_ref(),
                    &combined_codec,
                )
                .map(Arc::new)
                .map_err(|err| {
                    Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
                })?;

                // Add the extensions that might be required for ExecutionPlan nodes in the plan
                let config = state.config_mut();
                config.set_extension(Arc::clone(&self.channel_manager));
                config.set_extension(stage.clone());

                Ok::<_, Status>((state, stage))
            })
            .await?;

        Ok((state.clone(), stage.clone()))
    }
}
```
Diff of the module declarations and re-exports:

`@@ -1,9 +1,8 @@`

```rust
mod do_get;
mod service;
mod session_builder;
mod stream_partitioner_registry;

pub(crate) use do_get::DoGet;

pub use service::ArrowFlightEndpoint;
pub use service::{ArrowFlightEndpoint, StageKey};
pub use session_builder::{NoopSessionBuilder, SessionBuilder};
```
Review comment:

What do you think about moving forward with @jayshrivastava's changes in #83 for validating TPCH correctness instead of this? It might be slightly better to ensure the validation there because:
Reply (Author):

Yep, I thought the same thing, so I moved it out of the benchmarks and aligned it with @jayshrivastava's PR.