Skip to content

Commit b46ebb5

Browse files
authored
Merge pull request #3 from robtandy/rob.tandy/ray_shuffle
distributed stages working
2 parents 8702412 + 4ad0ac1 commit b46ebb5

File tree

9 files changed

+824
-38
lines changed

9 files changed

+824
-38
lines changed

datafusion_ray/context.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,26 @@ class RayDataFrame:
3636
def __init__(self, ray_internal_df: RayDataFrameInternal):
3737
self.df = ray_internal_df
3838
self.coordinator_id = self.df.coordinator_id
39-
self._stages = []
39+
self._stages = None
40+
self._batches = None
4041

41-
def stages(self):
42+
def stages(self, batch_size=8192):
4243
# create our coordinator now, which we need to create stages
4344
if not self._stages:
4445
self.coord = RayStageCoordinator.options(
4546
name="RayQueryCoordinator:" + self.coordinator_id,
4647
).remote(self.coordinator_id)
47-
self._stages = self.df.stages()
48+
self._stages = self.df.stages(batch_size)
4849
return self._stages
4950

5051
def execution_plan(self):
5152
return self.df.execution_plan()
5253

5354
def collect(self) -> list[pa.RecordBatch]:
54-
reader = self.reader()
55-
self.batches = list(reader)
56-
return self.batches
55+
if not self._batches:
56+
reader = self.reader()
57+
self._batches = list(reader)
58+
return self._batches
5759

5860
def show(self) -> None:
5961
table = pa.Table.from_batches(self.collect())
@@ -111,7 +113,7 @@ def __init__(self, coordinator_id: str) -> None:
111113
self.exchanger = RayExchanger.remote()
112114

113115
def get_exchanger(self):
114-
print("Coord: returning exchanger {self.exchanger}")
116+
print(f"Coord: returning exchanger {self.exchanger}")
115117
return self.exchanger
116118

117119
def new_stage(self, stage_id: str, plan_bytes: bytes):
@@ -123,7 +125,6 @@ def new_stage(self, stage_id: str, plan_bytes: bytes):
123125
print(f"creating new stage {stage_id} from bytes {len(plan_bytes)}")
124126
stage = RayStage.options(
125127
name="stage:" + stage_id,
126-
# lifetime="detached",
127128
).remote(stage_id, plan_bytes, self.my_id, self.exchanger)
128129
self.stages[stage_id] = stage
129130

@@ -155,7 +156,7 @@ def run_stages(self):
155156
raise e
156157

157158

158-
@ray.remote(num_cpus=1)
159+
@ray.remote(num_cpus=0)
159160
class RayStage:
160161
def __init__(
161162
self, stage_id: str, plan_bytes: bytes, coordinator_id: str, exchanger
@@ -178,8 +179,13 @@ def consume(self):
178179
reader = self.pystage.execute(partition)
179180
for batch in reader:
180181
ipc_batch = batch_to_ipc(batch)
182+
o_ref = ray.put(ipc_batch)
183+
184+
# upload a nested object, list[oref] so that ray does not
185+
# materialize it at the destination. The shuffler only
186+
# needs to exchange object refs
181187
ray.get(
182-
self.exchanger.put.remote(self.stage_id, partition, ipc_batch)
188+
self.exchanger.put.remote(self.stage_id, partition, [o_ref])
183189
)
184190
# signal there are no more batches
185191
ray.get(self.exchanger.put.remote(self.stage_id, partition, None))
@@ -209,7 +215,7 @@ async def put(self, stage_id, output_partition, item):
209215

210216
q = self.queues[key]
211217
await q.put(item)
212-
print(f"RayExchanger got batch for {key}")
218+
# print(f"RayExchanger got batch for {key}")
213219

214220
async def get(self, stage_id, output_partition):
215221
key = f"{stage_id}-{output_partition}"
@@ -257,17 +263,20 @@ def __init__(self, exchanger, stage_id, partition):
257263

258264
def __next__(self):
259265
obj_ref = self.exchanger.get.remote(self.stage_id, self.partition)
260-
print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ref")
261-
ipc_batch = ray.get(obj_ref)
266+
# print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ref")
267+
message = ray.get(obj_ref)
262268

263-
if ipc_batch is None:
269+
if message is None:
264270
raise StopIteration
265271

266-
print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ipc batch")
272+
# otherwise we know it's a list containing a single object ref
273+
ipc_batch = ray.get(message[0])
274+
275+
# print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ipc batch")
267276
batch = ipc_to_batch(ipc_batch)
268-
print(
269-
f"[RayIterable stage:{self.stage_id} p:{self.partition}] converted to batch"
270-
)
277+
# print(
278+
# f"[RayIterable stage:{self.stage_id} p:{self.partition}] converted to batch"
279+
# )
271280

272281
return batch
273282

examples/ray_stage.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import argparse
19+
import datafusion
20+
import glob
21+
import os
22+
import ray
23+
import pyarrow as pa
24+
from datafusion_ray import RayContext
25+
26+
27+
def go(data_dir: str, concurrency: int):
    """Run the example join/aggregate query through ``RayContext``.

    Registers every file in ``data_dir`` whose name ends in ``parquet`` as a
    table, prints the overall execution plan, prints each stage's id, plan
    and serialized size, and finally calls ``df.show()``.

    Args:
        data_dir: Directory holding the Parquet files. Each file is
            registered under the part of its base name before the first
            ``"."`` (``customer.parquet`` -> ``customer``).
        concurrency: Value written to
            ``datafusion.execution.target_partitions``.
    """
    import time

    ctx = RayContext()
    ctx.set("datafusion.execution.target_partitions", str(concurrency))
    ctx.set("datafusion.catalog.information_schema", "true")
    ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")

    for f in glob.glob(os.path.join(data_dir, "*parquet")):
        print(f)
        # Take everything before the first "." as the table name. Unpacking
        # the result of a plain split(".") raised ValueError for names with
        # no dot or with more than one (e.g. "orders.snappy.parquet").
        table = os.path.basename(f).split(".", 1)[0]
        ctx.register_parquet(table, f)

    query = """SELECT customer.c_name, sum(orders.o_totalprice) as total_amount
    FROM customer JOIN orders ON customer.c_custkey = orders.o_custkey
    GROUP BY customer.c_name limit 10"""

    # query = """SELECT count(customer.c_name), customer.c_mktsegment from customer group by customer.c_mktsegment limit 10"""

    df = ctx.sql(query)
    print(df.execution_plan().display_indent())
    for stage in df.stages():
        print("Stage ", stage.stage_id)
        print(stage.execution_plan().display_indent())
        b = stage.plan_bytes()
        print(f"Stage bytes: {len(b)}")

    df.show()

    # NOTE(review): presumably this pause lets remote Ray workers flush their
    # output before the script exits -- confirm whether it is still needed.
    time.sleep(3)
57+
58+
59+
if __name__ == "__main__":
    # Parse the command line before starting Ray, so that --help and
    # argument errors exit immediately instead of initializing Ray first.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", required=True, help="path to tpch*.parquet files")
    parser.add_argument("--concurrency", required=True, type=int)
    args = parser.parse_args()

    ray.init(namespace="example")
    go(args.data_dir, args.concurrency)

src/dataframe.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::tree_node::Transformed;
19+
use datafusion::common::tree_node::TreeNode;
20+
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
21+
use datafusion::physical_plan::displayable;
22+
use datafusion::physical_plan::ExecutionPlan;
23+
use datafusion_python::physical_plan::PyExecutionPlan;
24+
use pyo3::prelude::*;
25+
use std::sync::Arc;
26+
27+
use crate::pystage::PyStage;
28+
use crate::ray_stage::RayStageExec;
29+
use crate::ray_stage_reader::RayStageReaderExec;
30+
31+
/// Newtype wrapping a coordinator id `String`.
// NOTE(review): not referenced anywhere in the visible part of this file;
// presumably used by other modules -- confirm.
pub struct CoordinatorId(pub String);
32+
33+
// Class exposed to Python (`#[pyclass]`) that pairs a physical plan with the
// id of the coordinator it belongs to.
#[pyclass]
pub struct RayDataFrame {
    // Plan that `stages` walks and that `execution_plan` hands back to Python.
    physical_plan: Arc<dyn ExecutionPlan>,
    // Readable from Python as an attribute because of `#[pyo3(get)]`.
    #[pyo3(get)]
    coordinator_id: String,
}
39+
40+
impl RayDataFrame {
    /// Build a `RayDataFrame` from an already-planned physical plan and the
    /// coordinator id it should carry. Stores both arguments unchanged.
    pub fn new(physical_plan: Arc<dyn ExecutionPlan>, coordinator_id: String) -> Self {
        Self {
            physical_plan,
            coordinator_id,
        }
    }
}
48+
49+
#[pymethods]
impl RayDataFrame {
    // Collect one `PyStage` per `RayStageExec` node found in `physical_plan`.
    //
    // The plan is visited bottom-up. For every `RayStageExec`, its single
    // child becomes the stage's plan after (a) any `RayStageExec` nodes found
    // inside that child are swapped for `RayStageReaderExec` nodes and (b)
    // the result is wrapped in a `CoalesceBatchesExec` built with
    // `batch_size`. Stages are returned in the order they were visited.
    //
    // `batch_size` defaults to 8192 when omitted by the Python caller.
    //
    // Panics (via `assert!`) if a `RayStageExec` does not have exactly one
    // child; errors from either tree walk are propagated with `?`.
    #[pyo3(signature = (batch_size=8192))]
    fn stages(&self, batch_size: usize) -> PyResult<Vec<PyStage>> {
        let mut stages = vec![];

        // TODO: This can be done more efficiently, likely in one pass but I'm
        // struggling to get the TreeNodeRecursion return values to make it do
        // what I want. So, two steps for now

        // Step 2 (declared first because the `up` closure below uses it): we
        // walk down a stage's plan and replace stages earlier in the tree with
        // RayStageReaderExecs
        let down = |plan: Arc<dyn ExecutionPlan>| {
            //println!("examining plan: {}", displayable(plan.as_ref()).one_line());

            if let Some(stage_exec) = plan.as_any().downcast_ref::<RayStageExec>() {
                let input = plan.children();
                assert!(input.len() == 1, "RayStageExec must have exactly one child");
                let input = input[0];

                // The reader is constructed from the replaced stage's child,
                // that stage's id, and this dataframe's coordinator id.
                let replacement = Arc::new(RayStageReaderExec::try_new_from_input(
                    input.clone(),
                    stage_exec.stage_id.clone(),
                    self.coordinator_id.clone(),
                )?) as Arc<dyn ExecutionPlan>;

                Ok(Transformed::yes(replacement))
            } else {
                Ok(Transformed::no(plan))
            }
        };

        // Step 1: we walk up the tree from the leaves to find the stages
        let up = |plan: Arc<dyn ExecutionPlan>| {
            // NOTE(review): this prints once per visited plan node on every
            // call to `stages`; it looks like leftover debug output (the
            // matching line in `down` above is commented out) -- confirm
            // whether it should stay.
            println!("examining plan: {}", displayable(plan.as_ref()).one_line());

            if let Some(stage_exec) = plan.as_any().downcast_ref::<RayStageExec>() {
                let input = plan.children();
                assert!(input.len() == 1, "RayStageExec must have exactly one child");
                let input = input[0];

                let fixed_plan = input.clone().transform_down(down)?.data;

                // insert a coalescing batches here too so that we aren't sending
                // too small of batches over the network
                let final_plan = Arc::new(CoalesceBatchesExec::new(fixed_plan, batch_size))
                    as Arc<dyn ExecutionPlan>;

                let stage = PyStage::new(
                    stage_exec.stage_id.clone(),
                    final_plan,
                    self.coordinator_id.clone(),
                );

                /*println!(
                    "made new stage {}: plan:\n{}",
                    stage_exec.stage_id,
                    displayable(stage.plan.as_ref()).indent(true)
                );*/

                stages.push(stage);
            }

            // The walk is used only to visit nodes: the plan itself is always
            // returned unmodified.
            Ok(Transformed::no(plan))
        };

        // The transformed tree is discarded; only the `stages` side effect
        // of `up` is kept.
        self.physical_plan.clone().transform_up(up)?;

        Ok(stages)
    }

    // Return the stored physical plan to Python wrapped in a `PyExecutionPlan`.
    fn execution_plan(&self) -> PyResult<PyExecutionPlan> {
        Ok(PyExecutionPlan::new(self.physical_plan.clone()))
    }
}

src/physical.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl PhysicalOptimizerRule for RayShuffleOptimizerRule {
5353
let mut stage_counter = 0;
5454

5555
let up = |plan: Arc<dyn ExecutionPlan>| {
56-
println!("examining plan: {}", displayable(plan.as_ref()).one_line());
56+
//println!("examining plan: {}", displayable(plan.as_ref()).one_line());
5757

5858
if plan.as_any().downcast_ref::<RepartitionExec>().is_some() {
5959
let stage = Arc::new(RayStageExec::new(plan, stage_counter.to_string()));
@@ -67,10 +67,10 @@ impl PhysicalOptimizerRule for RayShuffleOptimizerRule {
6767
let plan = plan.transform_up(up)?.data;
6868
let final_plan = Arc::new(RayStageExec::new(plan, stage_counter.to_string()));
6969

70-
println!(
71-
"optimized physical plan:\n{}",
72-
displayable(final_plan.as_ref()).indent(false)
73-
);
70+
//println!(
71+
// "optimized physical plan:\n{}",
72+
// displayable(final_plan.as_ref()).indent(false)
73+
//);
7474
Ok(final_plan)
7575
}
7676

0 commit comments

Comments
 (0)