Skip to content

Commit f8faad2

Browse files
added broadcast join operator
1 parent 731ad2a commit f8faad2

File tree

2 files changed

+259
-0
lines changed

2 files changed

+259
-0
lines changed

src/execution_plans/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
mod common;
22
mod distributed;
33
mod metrics;
4+
mod network_broadcast;
45
mod network_coalesce;
56
mod network_shuffle;
67
mod partition_isolator;
78

89
pub use distributed::DistributedExec;
910
pub(crate) use metrics::MetricsWrapperExec;
11+
pub use network_broadcast::{NetworkBroadcastExec, NetworkBroadcastReady};
1012
pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady};
1113
pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec};
1214
pub use partition_isolator::PartitionIsolatorExec;
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
use crate::channel_resolver_ext::get_distributed_channel_resolver;
2+
use crate::config_extension_ext::ContextGrpcMetadata;
3+
use crate::distributed_planner::{InputStageInfo, NetworkBoundary};
4+
use crate::execution_plans::common::{require_one_child};
5+
use crate::flight_service::DoGet;
6+
use crate::metrics::MetricsCollectingStream;
7+
use crate::metrics::proto::MetricsSetProto;
8+
use crate::protobuf::{StageKey, map_flight_to_datafusion_error, map_status_to_datafusion_error};
9+
use crate::stage::{MaybeEncodedPlan, Stage};
10+
use crate::{ChannelResolver};
11+
use arrow_flight::Ticket;
12+
use arrow_flight::decode::FlightRecordBatchStream;
13+
use arrow_flight::error::FlightError;
14+
use bytes::Bytes;
15+
use dashmap::DashMap;
16+
use datafusion::common::{exec_err, internal_datafusion_err, plan_err};
17+
use datafusion::error::DataFusionError;
18+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
19+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20+
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21+
use futures::{StreamExt, TryFutureExt, TryStreamExt};
22+
use http::Extensions;
23+
use prost::Message;
24+
use std::any::Any;
25+
use std::sync::Arc;
26+
use tonic::Request;
27+
use tonic::metadata::MetadataMap;
28+
29+
#[derive(Debug, Clone)]
30+
pub enum NetworkBroadcastExec {
31+
Pending(NetworkBroadcastPending),
32+
Ready(NetworkBroadcastReady),
33+
}
34+
35+
#[derive(Debug, Clone)]
36+
pub struct NetworkBroadcastPending {
37+
properties: PlanProperties,
38+
input_tasks: usize,
39+
input: Arc<dyn ExecutionPlan>,
40+
}
41+
42+
#[derive(Debug, Clone)]
43+
pub struct NetworkBroadcastReady {
44+
pub(crate) properties: PlanProperties,
45+
pub(crate) input_stage: Stage,
46+
pub(crate) metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
47+
}
48+
49+
impl NetworkBroadcastExec {
50+
pub fn new(input: Arc<dyn ExecutionPlan>, input_tasks: usize) -> Self {
51+
Self::Pending(NetworkBroadcastPending {
52+
properties: input.properties().clone(),
53+
input_tasks,
54+
input,
55+
})
56+
}
57+
}
58+
59+
impl NetworkBoundary for NetworkBroadcastExec {
60+
fn get_input_stage_info(
61+
&self,
62+
_n_tasks: usize
63+
) -> datafusion::common::Result<InputStageInfo, DataFusionError> {
64+
let Self::Pending(pending) = self else {
65+
return plan_err!("cannot only return wrapped child if on Pending state");
66+
};
67+
68+
Ok(InputStageInfo {
69+
plan: Arc::clone(&pending.input),
70+
task_count: pending.input_tasks,
71+
})
72+
}
73+
74+
fn with_input_task_count(
75+
&self,
76+
input_tasks: usize,
77+
) -> datafusion::common::Result<Arc<dyn NetworkBoundary>> {
78+
match self {
79+
Self::Pending(pending) => Ok(Arc::new(Self::Pending(NetworkBroadcastPending {
80+
properties: pending.properties.clone(),
81+
input_tasks,
82+
input: pending.input.clone(),
83+
}))),
84+
Self::Ready(_) => {
85+
plan_err!("Self can only re-assign input tasks if in 'Pending' state")
86+
}
87+
}
88+
}
89+
90+
fn input_task_count(&self) -> usize {
91+
match self {
92+
Self::Pending(v) => v.input_tasks,
93+
Self::Ready(v) => v.input_stage.tasks.len(),
94+
}
95+
}
96+
97+
fn with_input_stage(
98+
&self,
99+
input_stage: Stage,
100+
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
101+
match self {
102+
Self::Pending(pending) => {
103+
let ready = NetworkBroadcastReady {
104+
properties: pending.properties.clone(),
105+
input_stage,
106+
metrics_collection: Default::default(),
107+
};
108+
Ok(Arc::new(Self::Ready(ready)))
109+
}
110+
Self::Ready(ready) => {
111+
let mut ready = ready.clone();
112+
ready.input_stage = input_stage;
113+
Ok(Arc::new(Self::Ready(ready)))
114+
}
115+
}
116+
}
117+
118+
fn input_stage(&self) -> Option<&Stage> {
119+
match self {
120+
Self::Pending(_) => None,
121+
Self::Ready(v) => Some(&v.input_stage),
122+
}
123+
}
124+
}
125+
126+
impl DisplayAs for NetworkBroadcastExec {
127+
fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
128+
match self {
129+
NetworkBroadcastExec::Pending(_) => {
130+
write!(f, "NetworkBroadcastExec: [Pending]")
131+
}
132+
NetworkBroadcastExec::Ready(ready) => {
133+
write!(
134+
f,
135+
"NetworkBroadcastExec: [Stage {}] ({} tasks)",
136+
ready.input_stage.num,
137+
ready.input_stage.tasks.len()
138+
)
139+
}
140+
}
141+
}
142+
}
143+
144+
impl ExecutionPlan for NetworkBroadcastExec {
145+
fn name(&self) -> &str {
146+
"NetworkBroadcastExec"
147+
}
148+
149+
fn as_any(&self) -> &dyn Any {
150+
self
151+
}
152+
153+
fn properties(&self) -> &PlanProperties {
154+
match self {
155+
NetworkBroadcastExec::Pending(v) => v.input.properties(),
156+
NetworkBroadcastExec::Ready(v) => &v.properties,
157+
}
158+
}
159+
160+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
161+
match self {
162+
NetworkBroadcastExec::Pending(v) => vec![&v.input],
163+
NetworkBroadcastExec::Ready(v) => match &v.input_stage.plan {
164+
MaybeEncodedPlan::Decoded(v) => vec![v],
165+
MaybeEncodedPlan::Encoded(_) => vec![],
166+
},
167+
}
168+
}
169+
170+
fn with_new_children(
171+
self: Arc<Self>,
172+
children: Vec<Arc<dyn ExecutionPlan>>,
173+
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
174+
match self.as_ref() {
175+
Self::Pending(v) => {
176+
let mut v = v.clone();
177+
v.input = require_one_child(children)?;
178+
Ok(Arc::new(Self::Pending(v)))
179+
}
180+
Self::Ready(v) => {
181+
let mut v = v.clone();
182+
v.input_stage.plan = MaybeEncodedPlan::Decoded(require_one_child(children)?);
183+
Ok(Arc::new(Self::Ready(v)))
184+
}
185+
}
186+
}
187+
188+
fn execute(
189+
&self,
190+
partition: usize,
191+
context: Arc<TaskContext>,
192+
) -> Result<SendableRecordBatchStream, DataFusionError> {
193+
let NetworkBroadcastExec::Ready(self_ready) = self else {
194+
return exec_err!(
195+
"NetworkBroadcastExec is not ready, was the distributed optimization step performed?"
196+
);
197+
};
198+
199+
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
200+
let input_stage = &self_ready.input_stage;
201+
let encoded_input_plan = input_stage.plan.encoded()?;
202+
let input_stage_tasks = input_stage.tasks.to_vec();
203+
let input_task_count = input_stage_tasks.len();
204+
let input_stage_num = input_stage.num as u64;
205+
let query_id = Bytes::from(input_stage.query_id.as_bytes().to_vec());
206+
let context_headers = ContextGrpcMetadata::headers_from_ctx(&context);
207+
208+
let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
209+
let channel_resolver = Arc::clone(&channel_resolver);
210+
let ticket = Request::from_parts(
211+
MetadataMap::from_headers(context_headers.clone()),
212+
Extensions::default(),
213+
Ticket {
214+
ticket: DoGet {
215+
plan_proto: encoded_input_plan.clone(),
216+
target_partition: partition as u64,
217+
stage_key: Some(StageKey::new(
218+
query_id.clone(),
219+
input_stage_num,
220+
i as u64,
221+
)),
222+
target_task_index: i as u64,
223+
target_task_count: input_task_count as u64,
224+
}
225+
.encode_to_vec()
226+
.into(),
227+
},
228+
);
229+
230+
let metrics_collection_capture = self_ready.metrics_collection.clone();
231+
async move {
232+
let url = task.url.ok_or(internal_datafusion_err!(
233+
"NetworkBroadcastExec: task is unassigned, cannot proceed"
234+
))?;
235+
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
236+
let stream = client
237+
.do_get(ticket)
238+
.await
239+
.map_err(map_status_to_datafusion_error)?
240+
.into_inner()
241+
.map_err(|err| FlightError::Tonic(Box::new(err)));
242+
let metrics_collecting_stream =
243+
MetricsCollectingStream::new(stream, metrics_collection_capture);
244+
Ok(
245+
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
246+
.map_err(map_flight_to_datafusion_error),
247+
)
248+
}
249+
.try_flatten_stream()
250+
.boxed()
251+
});
252+
Ok(Box::pin(RecordBatchStreamAdapter::new(
253+
self.schema(),
254+
futures::stream::select_all(stream),
255+
)))
256+
}
257+
}

0 commit comments

Comments
 (0)