
Commit 4f4c4ae

add ability to defer proto serialization for customization

1 parent 4f836d9

File tree: 13 files changed, +154 −66 lines


Cargo.lock

Lines changed: 1 addition & 0 deletions (generated file; diff not rendered by default)

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ path = "src/bin/distributed-datafusion.rs"
 anyhow = "1"
 arrow = { version = "55.1", features = ["ipc"] }
 arrow-flight = { version = "55", features = ["flight-sql-experimental"] }
+async-trait = "0.1.88"
 async-stream = "0.3"
 bytes = "1.5"
 clap = { version = "4.4", features = ["derive"] }

src/bin/distributed-datafusion.rs

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ async fn main() -> Result<()> {
             service.serve().await?;
         }
         "worker" => {
-            let service = DDWorkerService::new(new_friendly_name()?, args.port).await?;
+            let service = DDWorkerService::new(new_friendly_name()?, args.port, None).await?;
             service.serve().await?;
         }
         _ => {
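The worker binary now threads `None` through as the new third argument. A minimal sketch of a call site that supplies a customizer instead — assuming `DDWorkerService::new` takes the same `Option<Arc<dyn Customizer>>` as the proxy service below, and with `MyCustomizer` as a purely hypothetical implementor of the new `Customizer` trait (see the sketch after src/customizer.rs below):

```rust
use std::sync::Arc;

use anyhow::Result;

// Hypothetical call site: hand the worker a customizer instead of None.
// `Customizer`, `DDWorkerService`, `new_friendly_name`, and `MyCustomizer`
// are assumed to be in scope from this crate.
async fn serve_worker(port: usize) -> Result<()> {
    let customizer: Arc<dyn Customizer> = Arc::new(MyCustomizer::new());
    let service = DDWorkerService::new(new_friendly_name()?, port, Some(customizer)).await?;
    service.serve().await?;
    Ok(())
}
```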

src/codec.rs

Lines changed: 48 additions & 21 deletions
@@ -31,7 +31,21 @@ use crate::{
 };
 
 #[derive(Debug)]
-pub struct DDCodec {}
+pub struct DDCodec {
+    sub_codec: Arc<dyn PhysicalExtensionCodec>,
+}
+
+impl DDCodec {
+    pub fn new(sub_codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
+        Self { sub_codec }
+    }
+}
+
+impl Default for DDCodec {
+    fn default() -> Self {
+        Self::new(Arc::new(DefaultPhysicalExtensionCodec {}))
+    }
+}
 
 impl PhysicalExtensionCodec for DDCodec {
     fn try_decode(

@@ -127,6 +141,13 @@ impl PhysicalExtensionCodec for DDCodec {
                     Ok(Arc::new(RecordBatchExec::new(batch)))
                 }
             }
+        } else if let Ok(ext) = self.sub_codec.try_decode(buf, inputs, registry) {
+            // If the node is not a DDExecNode, we delegate to the sub codec
+            trace!(
+                "Delegated decoding to sub codec for node: {}",
+                displayable(ext.as_ref()).one_line()
+            );
+            Ok(ext)
         } else {
             internal_err!("cannot decode proto extension in distributed datafusion codec")
         }

@@ -151,52 +172,58 @@ impl PhysicalExtensionCodec for DDCodec {
                 stage_id: reader.stage_id,
             };
 
-            Payload::StageReaderExec(pb)
+            Some(Payload::StageReaderExec(pb))
         } else if let Some(pi) = node.as_any().downcast_ref::<PartitionIsolatorExec>() {
             let pb = PartitionIsolatorExecNode {
                 partition_count: pi.partition_count as u64,
             };
 
-            Payload::IsolatorExec(pb)
+            Some(Payload::IsolatorExec(pb))
         } else if let Some(max) = node.as_any().downcast_ref::<MaxRowsExec>() {
             let pb = MaxRowsExecNode {
                 max_rows: max.max_rows as u64,
             };
-            Payload::MaxRowsExec(pb)
+            Some(Payload::MaxRowsExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeExec>() {
             let pb = DistributedAnalyzeExecNode {
                 verbose: exec.verbose,
                 show_statistics: exec.show_statistics,
             };
-            Payload::DistributedAnalyzeExec(pb)
+            Some(Payload::DistributedAnalyzeExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeRootExec>() {
             let pb = DistributedAnalyzeRootExecNode {
                 verbose: exec.verbose,
                 show_statistics: exec.show_statistics,
             };
-            Payload::DistributedAnalyzeRootExec(pb)
+            Some(Payload::DistributedAnalyzeRootExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<RecordBatchExec>() {
             let pb = RecordBatchExecNode {
                 batch: batch_to_ipc(&exec.batch).map_err(|e| {
                     internal_datafusion_err!("Failed to encode RecordBatch: {:#?}", e)
                 })?,
             };
-            Payload::RecordBatchExec(pb)
+            Some(Payload::RecordBatchExec(pb))
         } else {
-            return internal_err!("Not supported node to encode to proto");
+            trace!(
+                "Node {} is not a custom DDExecNode, delegating to sub codec",
+                displayable(node.as_ref()).one_line()
+            );
+            None
         };
 
-        let pb = DdExecNode {
-            payload: Some(payload),
-        };
-        pb.encode(buf)
-            .map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {}", e))?;
-
-        trace!(
-            "DONE encoding node: {}",
-            displayable(node.as_ref()).one_line()
-        );
-        Ok(())
+        match payload {
+            Some(payload) => {
+                let pb = DdExecNode {
+                    payload: Some(payload),
+                };
+                pb.encode(buf)
+                    .map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {:#?}", e))
+            }
+            None => {
+                // If the node is not one of our custom nodes, we delegate to the sub codec
+                self.sub_codec.try_encode(node, buf)
+            }
+        }
     }
 }

@@ -225,7 +252,7 @@ mod test {
 
     fn verify_round_trip(exec: Arc<dyn ExecutionPlan>) {
         let ctx = SessionContext::new();
-        let codec = DDCodec {};
+        let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));
 
         // serialize execution plan to proto
         let proto: protobuf::PhysicalPlanNode =

@@ -255,7 +282,7 @@ mod test {
         let schema = create_test_schema();
         let part = Partitioning::UnknownPartitioning(2);
         let exec = Arc::new(DDStageReaderExec::try_new(part, schema, 1).unwrap());
-        let codec = DDCodec {};
+        let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));
         let mut buf = vec![];
         codec.try_encode(exec.clone(), &mut buf).unwrap();
         let ctx = SessionContext::new();
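The net effect of these changes: `try_encode` now builds an `Option<Payload>`, and both serde paths fall through to the wrapped `sub_codec` whenever a node is not one of the crate's own `DDExecNode` variants, instead of returning an error. A small construction sketch — `build_codec` is a hypothetical helper, and the `use` path for `DDCodec` assumes the crate is imported as `distributed_datafusion`:

```rust
use std::sync::Arc;

use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
use distributed_datafusion::codec::DDCodec; // assumed crate/module path

// Hypothetical helper: choose the sub codec DDCodec should defer to.
fn build_codec(user_codec: Option<Arc<dyn PhysicalExtensionCodec>>) -> DDCodec {
    match user_codec {
        // DDCodec tries its own DDExecNode variants first and only
        // delegates unrecognized nodes to this sub codec.
        Some(sub) => DDCodec::new(sub),
        // Per the new Default impl, equivalent to
        // DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {})).
        None => DDCodec::default(),
    }
}
```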

src/ctx_customizer.rs

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/customizer.rs

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+use datafusion::prelude::SessionContext;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+
+#[async_trait::async_trait]
+pub trait Customizer: PhysicalExtensionCodec + Send + Sync {
+    /// Customize the context before planning a query.
+    async fn customize(&self, ctx: &mut SessionContext) -> Result<(), Box<dyn std::error::Error>>;
+}
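Because `Customizer` extends `PhysicalExtensionCodec`, one object now supplies both the async context hook and the proto serde that `DDCodec` defers to. A minimal sketch of an implementor, assuming the crate is imported as `distributed_datafusion`; `MyCustomizer` is hypothetical and has no custom plan nodes, so its serde simply falls back to `DefaultPhysicalExtensionCodec`:

```rust
use std::sync::Arc;

use datafusion::error::Result as DFResult;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
use distributed_datafusion::customizer::Customizer; // assumed crate/module path

// Hypothetical implementor with no custom ExecutionPlan nodes.
#[derive(Debug)]
struct MyCustomizer {
    inner: DefaultPhysicalExtensionCodec,
}

impl MyCustomizer {
    fn new() -> Self {
        Self { inner: DefaultPhysicalExtensionCodec {} }
    }
}

impl PhysicalExtensionCodec for MyCustomizer {
    fn try_decode(
        &self,
        buf: &[u8],
        inputs: &[Arc<dyn ExecutionPlan>],
        registry: &dyn FunctionRegistry,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // No custom nodes to decode; let the default codec reject the bytes.
        self.inner.try_decode(buf, inputs, registry)
    }

    fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> DFResult<()> {
        // Likewise, nothing custom to encode.
        self.inner.try_encode(node, buf)
    }
}

#[async_trait::async_trait]
impl Customizer for MyCustomizer {
    async fn customize(&self, ctx: &mut SessionContext) -> Result<(), Box<dyn std::error::Error>> {
        // Tweak the context before planning, e.g. set options or register UDFs.
        ctx.sql("SET datafusion.execution.batch_size = 8192").await?;
        Ok(())
    }
}
```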

src/explain.rs

Lines changed: 8 additions & 3 deletions
@@ -11,13 +11,17 @@ use datafusion::{
     physical_plan::{displayable, ExecutionPlan},
     prelude::SessionContext,
 };
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 
 use crate::{result::Result, util::bytes_to_physical_plan, vocab::DDTask};
 
-pub fn format_distributed_tasks(tasks: &[DDTask]) -> Result<String> {
+pub fn format_distributed_tasks(
+    tasks: &[DDTask],
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<String> {
     let mut result = String::new();
     for (i, task) in tasks.iter().enumerate() {
-        let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes)
+        let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes, codec)
             .context("unable to decode task plan for formatted output")?;
 
         result.push_str(&format!(

@@ -45,6 +49,7 @@ pub fn build_explain_batch(
     physical_plan: &Arc<dyn ExecutionPlan>,
     distributed_plan: &Arc<dyn ExecutionPlan>,
     distributed_tasks: &[DDTask],
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<RecordBatch> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("plan_type", DataType::Utf8, false),

@@ -64,7 +69,7 @@ pub fn build_explain_batch(
         displayable(distributed_plan.as_ref())
             .indent(true)
             .to_string(),
-        format_distributed_tasks(distributed_tasks)?,
+        format_distributed_tasks(distributed_tasks, codec)?,
     ]);
 
     let batch = RecordBatch::try_new(schema, vec![Arc::new(plan_types), Arc::new(plans)])?;

src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ pub use proto::generated::protobuf;
 
 pub mod analyze;
 pub mod codec;
-pub mod ctx_customizer;
+pub mod customizer;
 pub mod explain;
 pub mod flight;
 pub mod friendly;

src/planning.rs

Lines changed: 5 additions & 2 deletions
@@ -24,6 +24,7 @@ use datafusion::{
     },
     prelude::{SQLOptions, SessionConfig, SessionContext},
 };
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use futures::TryStreamExt;
 use itertools::Itertools;
 use prost::Message;

@@ -439,6 +440,7 @@ pub async fn distribute_stages(
     query_id: &str,
     stages: Vec<DDStage>,
     worker_addrs: Vec<Host>,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(Addrs, Vec<DDTask>)> {
     // map of worker name to address
     // FIXME: use types over tuples of strings, as we can accidentally swap them and

@@ -457,7 +459,7 @@ pub async fn distribute_stages(
 
     // all stages to workers
     let (task_datas, final_addrs) =
-        assign_to_workers(query_id, &stages, workers.values().collect())?;
+        assign_to_workers(query_id, &stages, workers.values().collect(), codec)?;
 
     // we retry this a few times to ensure that the workers are ready
     // and can accept the stages

@@ -551,6 +553,7 @@ fn assign_to_workers(
     query_id: &str,
     stages: &[DDStage],
     worker_addrs: Vec<&Host>,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(Vec<DDTask>, Addrs)> {
     let mut task_datas = vec![];
     let mut worker_idx = 0;

@@ -570,7 +573,7 @@ fn assign_to_workers(
 
     for stage in stages {
         for partition_group in stage.partition_groups.iter() {
-            let plan_bytes = physical_plan_to_bytes(stage.plan.clone())?;
+            let plan_bytes = physical_plan_to_bytes(stage.plan.clone(), codec)?;
 
             let host = worker_addrs[worker_idx].clone();
             worker_idx = (worker_idx + 1) % worker_addrs.len();

src/proxy_service.rs

Lines changed: 9 additions & 11 deletions
@@ -40,7 +40,7 @@ use tokio::{
 use tonic::{async_trait, transport::Server, Request, Response, Status};
 
 use crate::{
-    ctx_customizer::CtxCustomizer,
+    customizer::Customizer,
     flight::{FlightSqlHandler, FlightSqlServ},
     logging::{debug, info, trace},
     planning::{add_ctx_extentions, get_ctx},

@@ -59,15 +59,12 @@ pub struct DDProxyHandler {
 
     pub planner: QueryPlanner,
 
-    pub ctx_customizer: Option<Arc<dyn CtxCustomizer + Send + Sync>>,
+    /// Optional customizer for our context and proto serde
+    pub customizer: Option<Arc<dyn Customizer>>,
 }
 
 impl DDProxyHandler {
-    pub fn new(
-        name: String,
-        addr: String,
-        ctx_customizer: Option<Arc<dyn CtxCustomizer + Send + Sync>>,
-    ) -> Self {
+    pub fn new(name: String, addr: String, customizer: Option<Arc<dyn Customizer>>) -> Self {
         // call this function to bootstrap the worker discovery mechanism
         get_worker_addresses().expect("Could not get worker addresses upon startup");
 
@@ -77,8 +74,8 @@ impl DDProxyHandler {
         };
         Self {
             host: host.clone(),
-            planner: QueryPlanner::new(),
-            ctx_customizer,
+            planner: QueryPlanner::new(customizer.clone()),
+            customizer,
         }
     }

@@ -126,8 +123,9 @@ impl DDProxyHandler {
         add_ctx_extentions(&mut ctx, &self.host, &query_id, stage_id, addrs, vec![])
             .map_err(|e| Status::internal(format!("Could not add context extensions {e:?}")))?;
 
-        if let Some(ref c) = self.ctx_customizer {
+        if let Some(ref c) = self.customizer {
             c.customize(&mut ctx)
+                .await
                 .map_err(|e| Status::internal(format!("Could not customize context {e:?}")))?;
         }
 
@@ -294,7 +292,7 @@ impl DDProxyService {
     pub async fn new(
         name: String,
         port: usize,
-        ctx_customizer: Option<Arc<dyn CtxCustomizer + Send + Sync>>,
+        ctx_customizer: Option<Arc<dyn Customizer>>,
     ) -> Result<Self> {
         debug!("Creating DDProxyService!");
 
