Skip to content

Commit 36aef70

Browse files
authored
Merge pull request #6 from robtandy/rob.tandy/arrow-flight
randomly assign stages to exchanges
2 parents 1c7f499 + 27c00ab commit 36aef70

File tree

7 files changed

+119
-126
lines changed

7 files changed

+119
-126
lines changed

datafusion_ray/core.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import uuid
2323
import os
2424
import time
25+
import random
2526

2627
from datafusion_ray._datafusion_ray_internal import (
2728
RayContext as RayContextInternal,
@@ -35,6 +36,7 @@ class RayDataFrame:
3536
def __init__(
3637
self,
3738
ray_internal_df: RayDataFrameInternal,
39+
num_partitions,
3840
batch_size=8192,
3941
isolate_parititions=False,
4042
bucket: str | None = None,
@@ -48,6 +50,7 @@ def __init__(
4850
self.isolate_partitions = isolate_parititions
4951
self.bucket = bucket
5052
self.num_exchangers = num_exchangers
53+
self.num_partitions = num_partitions
5154

5255
def stages(self):
5356
# create our coordinator now, which we need to create stages
@@ -56,7 +59,12 @@ def stages(self):
5659

5760
self.coord = RayStageCoordinator.options(
5861
name="RayQueryCoordinator:" + self.coordinator_id,
59-
).remote(self.coordinator_id, len(self._stages), self.num_exchangers)
62+
).remote(
63+
self.coordinator_id,
64+
len(self._stages),
65+
self.num_exchangers,
66+
self.num_partitions,
67+
)
6068

6169
ray.get(self.coord.start_up.remote())
6270
print("ray coord started up")
@@ -74,7 +82,7 @@ def collect(self) -> list[pa.RecordBatch]:
7482

7583
last_stage = max([stage.stage_id for stage in self._stages])
7684

77-
ref = self.coord.get_exchanger_addr.remote(last_stage)
85+
ref = self.coord.get_exchanger_addr.remote(last_stage, partition=0)
7886
self.create_ray_stages()
7987
t3 = time.time()
8088
print(f"creating ray stage actors took {t3 -t2}s")
@@ -86,7 +94,7 @@ def collect(self) -> list[pa.RecordBatch]:
8694
)
8795

8896
print("calling df execute")
89-
reader = self.df.execute({last_stage: addr})
97+
reader = self.df.execute({(last_stage, 0): addr})
9098
print("called df execute, got reader")
9199
self._batches = list(reader)
92100
self.coord.all_done.remote()
@@ -187,6 +195,7 @@ def sql(self, query: str) -> RayDataFrame:
187195
df = self.ctx.sql(query, coordinator_id)
188196
return RayDataFrame(
189197
df,
198+
self.ctx.get_target_partitions(),
190199
self.batch_size,
191200
self.isolate_partitions,
192201
self.bucket,
@@ -205,12 +214,17 @@ def set(self, option: str, value: str) -> None:
205214
@ray.remote(num_cpus=0)
206215
class RayStageCoordinator:
207216
def __init__(
208-
self, coordinator_id: str, num_stages: int, num_exchangers: int
217+
self,
218+
coordinator_id: str,
219+
num_stages: int,
220+
num_exchangers: int,
221+
num_partitions: int,
209222
) -> None:
210223
self.my_id = coordinator_id
211224
self.stages = {}
212225
self.num_stages = num_stages
213226
self.num_exchangers = num_exchangers
227+
self.num_partitions = num_partitions
214228
self.runtime_env = {}
215229

216230
def start_up(self):
@@ -221,28 +235,27 @@ def start_up(self):
221235
RayExchanger.remote(f"Exchanger #{i}") for i in range(self.num_exchangers)
222236
]
223237

224-
stages_per_exchanger = max(1, self.num_stages // self.num_exchangers)
225-
print("Stages per exchanger: ", stages_per_exchanger)
226-
227238
refs = [exchange.start_up.remote() for exchange in self.xs]
228239

229240
# ensure we've done the necessary initialization before continuing
230241
ray.wait(refs, num_returns=len(refs))
231242
print("all exchanges started up")
232243

233-
self.exchanges = {}
244+
# for each possible stage, and partition, assign it to an exchanger
234245
self.exchange_addrs = {}
235-
for i in range(self.num_stages):
236-
exchanger_i = min(len(self.xs) - 1, i // stages_per_exchanger)
237-
print("exchanger_i = ", exchanger_i)
238-
self.exchanges[i] = self.xs[exchanger_i]
239-
self.exchange_addrs[i] = ray.get(self.xs[exchanger_i].addr.remote())
246+
for stage_num in range(self.num_stages):
247+
for partition_num in range(self.num_partitions):
248+
exchanger_idx = random.choice(range(self.num_exchangers))
249+
self.exchange_addrs[(stage_num, partition_num)] = ray.get(
250+
self.xs[exchanger_idx].addr.remote()
251+
)
252+
print(self.exchange_addrs)
240253

241254
# don't wait for these
242255
[exchange.serve.remote() for exchange in self.xs]
243256

244-
def get_exchanger_addr(self, stage_num: int):
245-
return self.exchange_addrs[stage_num]
257+
def get_exchanger_addr(self, stage_num: int, partition: int):
258+
return self.exchange_addrs[(stage_num, partition)]
246259

247260
def all_done(self):
248261
print("calling exchangers all done")
@@ -269,16 +282,6 @@ def new_stage(
269282
):
270283
stage_key = f"{stage_id}-{shadow_partition}"
271284
try:
272-
if stage_key in self.stages:
273-
print(f"already started stage {stage_key}")
274-
return self.stages[stage_key]
275-
276-
exchange_addr = self.exchange_addrs[stage_id]
277-
278-
input_exchange_addrs = {
279-
input_stage_id: self.exchange_addrs[input_stage_id]
280-
for input_stage_id in input_stage_ids
281-
}
282285

283286
print(f"creating new stage {stage_key} from bytes {len(plan_bytes)}")
284287
stage = RayStage.options(
@@ -287,8 +290,7 @@ def new_stage(
287290
).remote(
288291
stage_id,
289292
plan_bytes,
290-
exchange_addr,
291-
input_exchange_addrs,
293+
self.exchange_addrs,
292294
fraction,
293295
shadow_partition,
294296
bucket,
@@ -337,36 +339,31 @@ def __init__(
337339
self,
338340
stage_id: str,
339341
plan_bytes: bytes,
340-
exchanger_addr: str,
341-
input_exchange_addrs: dict[int, str],
342+
exchanger_addrs: dict[tuple[int, int], str],
342343
fraction: float,
343344
shadow_partition=None,
344345
bucket: str | None = None,
345346
):
346347

347348
from datafusion_ray._datafusion_ray_internal import PyStage
348349

350+
self.shadow_partition = shadow_partition
351+
shadow = (
352+
f", shadowing:{self.shadow_partition}"
353+
if self.shadow_partition is not None
354+
else ""
355+
)
356+
349357
try:
350358
self.stage_id = stage_id
351359
self.pystage = PyStage(
352360
stage_id,
353361
plan_bytes,
354-
exchanger_addr,
355-
input_exchange_addrs,
362+
exchanger_addrs,
356363
shadow_partition,
357364
bucket,
358365
fraction,
359366
)
360-
self.shadow_partition = shadow_partition
361-
shadow = (
362-
f", shadowing:{self.shadow_partition}"
363-
if self.shadow_partition is not None
364-
else ""
365-
)
366-
367-
print(
368-
f"RayStage[{self.stage_id}{shadow}] Sending to {exchanger_addr}, consuming from {input_exchange_addrs}"
369-
)
370367
except Exception as e:
371368
print(
372369
f"RayStage[{self.stage_id}{shadow}] Unhandled Exception in init: {e}!"

src/context.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ impl RayContext {
124124
Ok(())
125125
}
126126

127+
pub fn get_target_partitions(&self) -> usize {
128+
let state = self.ctx.state_ref();
129+
let guard = state.read();
130+
let config = guard.config();
131+
let options = config.options();
132+
options.execution.target_partitions
133+
}
134+
127135
pub fn set_coordinator_id(&self, id: String) -> PyResult<()> {
128136
let state = self.ctx.state_ref();
129137
let mut guard = state.write();

src/dataframe.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use tonic::transport::Channel;
4242

4343
use crate::isolator::PartitionIsolatorExec;
4444
use crate::max_rows::MaxRowsExec;
45-
use crate::pystage::ExchangeFlightClient;
45+
use crate::pystage::ExchangeAddrs;
4646
use crate::ray_stage::RayStageExec;
4747
use crate::ray_stage_reader::RayStageReaderExec;
4848
use crate::util::make_client;
@@ -198,21 +198,9 @@ impl RayDataFrame {
198198

199199
pub fn execute(
200200
&self,
201-
py: Python,
202-
in_exchange_addrs: HashMap<usize, String>,
201+
exchange_addrs: HashMap<(usize, usize), String>,
203202
) -> PyResult<PyRecordBatchStream> {
204-
// TODO: consolidate this code
205-
//
206-
let in_client_map: HashMap<usize, FlightClient> = in_exchange_addrs
207-
.iter()
208-
.map(|(stage_num, addr)| {
209-
let client = make_client(py, addr).to_py_err()?;
210-
Ok::<_, PyErr>((stage_num.clone(), client))
211-
})
212-
.collect::<Result<HashMap<_, _>, PyErr>>()?;
213-
214-
let config =
215-
SessionConfig::new().with_extension(Arc::new(ExchangeFlightClient(in_client_map)));
203+
let config = SessionConfig::new().with_extension(Arc::new(ExchangeAddrs(exchange_addrs)));
216204

217205
let state = SessionStateBuilder::new()
218206
.with_default_features()

src/pystage.rs

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use datafusion::common::tree_node::{Transformed, TreeNode};
1414
use datafusion::common::{internal_datafusion_err, internal_err};
1515
use datafusion::physical_plan::{collect, displayable};
1616
use datafusion_python::physical_plan::PyExecutionPlan;
17-
use futures::future::try_join_all;
17+
use futures::future::{join_all, try_join_all};
1818
use futures::stream::FuturesUnordered;
1919
use object_store::aws::AmazonS3Builder;
2020
use prost::Message;
@@ -34,12 +34,12 @@ use pyo3::prelude::*;
3434

3535
use anyhow::Result;
3636

37-
pub struct ExchangeFlightClient(pub HashMap<usize, FlightClient>);
37+
/// a map of (stage_id, partition_id) to FlightClient used to speak to that Exchanger
38+
pub(crate) struct ExchangeAddrs(pub HashMap<(usize, usize), String>);
3839

3940
#[pyclass]
4041
pub struct PyStage {
4142
name: String,
42-
out_client: FlightClient,
4343
#[pyo3(get)]
4444
stage_id: usize,
4545
pub(crate) plan: Arc<dyn ExecutionPlan>,
@@ -51,13 +51,12 @@ pub struct PyStage {
5151
#[pymethods]
5252
impl PyStage {
5353
#[new]
54-
#[pyo3(signature = (stage_id, plan_bytes, out_exchange_addr, in_exchange_addrs, shadow_partition_number=None, bucket=None, fraction=1.0))]
54+
#[pyo3(signature = (stage_id, plan_bytes, exchange_addrs, shadow_partition_number=None, bucket=None, fraction=1.0))]
5555
pub fn from_bytes(
5656
py: Python,
5757
stage_id: usize,
5858
plan_bytes: Vec<u8>,
59-
out_exchange_addr: String,
60-
in_exchange_addrs: HashMap<usize, String>,
59+
exchange_addrs: HashMap<(usize, usize), String>,
6160
shadow_partition_number: Option<usize>,
6261
bucket: Option<String>,
6362
fraction: f64,
@@ -69,19 +68,9 @@ impl PyStage {
6968
.map(|s| s.to_string())
7069
.unwrap_or("n/a".into())
7170
);
72-
let out_client = make_client(py, &out_exchange_addr)?;
73-
74-
// make our own clone as FlightClient is not Clone, but inner is
75-
let out_client_map: HashMap<usize, FlightClient> = in_exchange_addrs
76-
.iter()
77-
.map(|(stage_num, addr)| {
78-
let client = make_client(py, addr).to_py_err()?;
79-
Ok::<_, PyErr>((stage_num.clone(), client))
80-
})
81-
.collect::<Result<HashMap<_, _>, PyErr>>()?;
8271

8372
let mut config =
84-
SessionConfig::new().with_extension(Arc::new(ExchangeFlightClient(out_client_map)));
73+
SessionConfig::new().with_extension(Arc::new(ExchangeAddrs(exchange_addrs)));
8574

8675
// this only matters if the plan includes an PartitionIsolatorExec
8776
// and will be ignored otherwise
@@ -116,7 +105,6 @@ impl PyStage {
116105

117106
Ok(Self {
118107
name,
119-
out_client,
120108
stage_id,
121109
plan,
122110
ctx,
@@ -128,39 +116,47 @@ impl PyStage {
128116
pub fn execute(&mut self, py: Python) -> PyResult<()> {
129117
println!("{} executing", self.name);
130118

119+
let addrs = &self
120+
.ctx
121+
.state()
122+
.config()
123+
.get_extension::<ExchangeAddrs>()
124+
.ok_or(internal_datafusion_err!("Flight Client not in context"))?
125+
.clone()
126+
.0;
127+
131128
let futs = (0..self.num_output_partitions()).map(|partition| {
132129
let ctx = self.ctx.task_ctx();
133-
// make our own clone as FlightClient is not Clone, but inner is
134-
let inner = self.out_client.inner().clone();
135-
let client_clone = FlightClient::new_from_inner(inner);
136130
let plan = self.plan.clone();
137-
let stage_id = self.stage_id;
138-
let fraction = self.fraction;
139-
let shadow_partition_number = self.shadow_partition_number;
140-
141-
tokio::spawn(consume_stage(
142-
stage_id,
143-
shadow_partition_number,
144-
fraction,
145-
ctx,
146-
partition,
147-
plan,
148-
client_clone,
149-
))
131+
let stage_id = self.stage_id.clone();
132+
let fraction = self.fraction.clone();
133+
let shadow_partition_number = self.shadow_partition_number.clone();
134+
135+
async move {
136+
// TODO propagate these errors appropriately
137+
let client = addrs
138+
.get(&(stage_id, partition))
139+
.map(|addr| make_client(addr))
140+
.expect("cannot find addr")
141+
.await
142+
.expect("cannot make client");
143+
144+
tokio::spawn(consume_stage(
145+
stage_id,
146+
shadow_partition_number,
147+
fraction,
148+
ctx,
149+
partition,
150+
plan,
151+
client,
152+
));
153+
}
150154
});
151155

152-
let name = self.name.clone();
153-
let fut = async {
154-
match try_join_all(futs).await {
155-
Ok(_) => Ok(()),
156-
Err(e) => {
157-
println!("{name}:ERROR executing {e}");
158-
Err(e)
159-
}
160-
}
161-
};
156+
let fut = join_all(futs);
162157

163-
wait_for_future(py, fut).to_py_err()
158+
wait_for_future(py, fut);
159+
Ok(())
164160
}
165161

166162
pub fn num_output_partitions(&self) -> usize {

0 commit comments

Comments
 (0)