Commit b82f63f

Removing session context from ray context and testing against running sql query
1 parent b84bb6d commit b82f63f

2 files changed: +71 -53 lines changed

examples/tips.py

Lines changed: 6 additions & 3 deletions

@@ -18,6 +18,7 @@
 import os
 import ray
 
+from datafusion import SessionContext
 from datafusion_ray import DatafusionRayContext
 
 SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -26,12 +27,14 @@
 ray.init()
 
 # Create a context and register a table
-ctx = DatafusionRayContext(2)
+df_ctx = SessionContext()
+
+ray_ctx = DatafusionRayContext(df_ctx, num_workers=2, use_ray_shuffle=True)
 # Register either a CSV or Parquet file
 # ctx.register_csv("tips", f"{SCRIPT_DIR}/tips.csv", True)
-ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")
+df_ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")
 
-result_set = ctx.sql(
+result_set = ray_ctx.sql(
     "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker"
 )
 for record_batch in result_set:

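The diff view truncates the example at the loop header. A minimal sketch of how the record batches returned by ray_ctx.sql() might be consumed; the loop body here is an illustrative assumption, not part of the commit:

# Hypothetical continuation of examples/tips.py (not in this commit).
# Assumes each item in result_set is a pyarrow.RecordBatch.
for record_batch in result_set:
    # Convert the Arrow batch to a pandas DataFrame for display.
    print(record_batch.to_pandas())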
src/context.rs

Lines changed: 65 additions & 50 deletions

@@ -25,14 +25,21 @@ use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::TaskContext;
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::FairSpillPool;
+use datafusion::execution::options::ReadOptions;
+use datafusion::execution::registry::MemoryFunctionRegistry;
 use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{displayable, ExecutionPlan};
 use datafusion::prelude::*;
 use datafusion_proto::bytes::{
     physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
 };
+use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
+use datafusion_proto::protobuf;
 use datafusion_python::physical_plan::PyExecutionPlan;
 use futures::StreamExt;
+use prost::{DecodeError, Message};
+use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
 use pyo3::types::{PyList, PyLong, PyTuple};
 use std::collections::HashMap;
@@ -44,74 +51,82 @@ type PyResultSet = Vec<PyObject>;
 
 #[pyclass(name = "Context", module = "datafusion_ray", subclass)]
 pub struct PyContext {
-    pub(crate) ctx: SessionContext,
+    pub(crate) py_ctx: PyObject,
 }
 
 #[pymethods]
 impl PyContext {
     #[new]
-    pub fn new(target_partitions: usize) -> Result<Self> {
-        let config = SessionConfig::default()
-            .with_target_partitions(target_partitions)
-            .with_batch_size(16 * 1024)
-            .with_repartition_aggregations(true)
-            .with_repartition_windows(true)
-            .with_repartition_joins(true)
-            .with_parquet_pruning(true);
-
-        let mem_pool_size = 1024 * 1024 * 1024;
-        let runtime_config = datafusion::execution::runtime_env::RuntimeConfig::new()
-            .with_memory_pool(Arc::new(FairSpillPool::new(mem_pool_size)))
-            .with_disk_manager(DiskManagerConfig::new_specified(vec!["/tmp".into()]));
-        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-        let ctx = SessionContext::new_with_config_rt(config, runtime);
-        Ok(Self { ctx })
+    pub fn new(session_ctx: PyObject) -> Result<Self> {
+        Ok(Self {
+            py_ctx: session_ctx,
+        })
     }
 
-    pub fn register_csv(
-        &self,
-        name: &str,
-        path: &str,
-        has_header: bool,
-        py: Python,
-    ) -> PyResult<()> {
-        let options = CsvReadOptions::default().has_header(has_header);
-        wait_for_future(py, self.ctx.register_csv(name, path, options))?;
-        Ok(())
-    }
+    // pub fn register_csv(
+    //     &self,
+    //     name: &str,
+    //     path: &str,
+    //     has_header: bool,
+    //     py: Python,
+    // ) -> PyResult<()> {
+    //     let options = CsvReadOptions::default().has_header(has_header);
+    //     wait_for_future(py, self.ctx.register_csv(name, path, options))?;
+    //     Ok(())
+    // }
 
-    pub fn register_parquet(&self, name: &str, path: &str, py: Python) -> PyResult<()> {
-        let options = ParquetReadOptions::default();
-        wait_for_future(py, self.ctx.register_parquet(name, path, options))?;
-        Ok(())
-    }
+    // pub fn register_parquet(&self, name: &str, path: &str, py: Python) -> PyResult<()> {
+    //     let options = ParquetReadOptions::default();
+    //     wait_for_future(py, self.ctx.register_parquet(name, path, options))?;
+    //     Ok(())
+    // }
 
-    pub fn register_datalake_table(
-        &self,
-        _name: &str,
-        _path: Vec<String>,
-        _py: Python,
-    ) -> PyResult<()> {
-        // let options = ParquetReadOptions::default();
-        // let listing_options = options.to_listing_options(&self.ctx.state().config());
-        // wait_for_future(py, self.ctx.register_listing_table(name, path, listing_options, None, None))?;
-        // Ok(())
-        unimplemented!()
-    }
+    // pub fn register_datalake_table(
+    //     &self,
+    //     name: &str,
+    //     path: Vec<String>,
+    //     py: Python,
+    // ) -> PyResult<()> {
+    //     // let options = ParquetReadOptions::default();
+    //     // let listing_options = options.to_listing_options(&self.ctx.state().config());
+    //     // wait_for_future(py, self.ctx.register_listing_table(name, path, listing_options, None, None))?;
+    //     // Ok(())
+    //     unimplemented!()
+    // }
 
     /// Execute SQL directly against the DataFusion context. Useful for statements
     /// such as "create view" or "drop view"
-    pub fn sql(&self, sql: &str, py: Python) -> PyResult<()> {
-        println!("Executing {}", sql);
-        let _df = wait_for_future(py, self.ctx.sql(sql))?;
+    pub fn sql(&self, query: &str, py: Python) -> PyResult<()> {
+        println!("Executing {}", query);
+        // let _df = wait_for_future(py, self.ctx.sql(sql))?;
+        let _df = self.run_sql(query, py);
         Ok(())
     }
 
+    fn run_sql(&self, query: &str, py: Python) -> PyResult<Py<PyAny>> {
+        let args = PyTuple::new_bound(py, [query]);
+        self.py_ctx.call_method1(py, "sql", args)
+    }
+
     /// Plan a distributed SELECT query for executing against the Ray workers
     pub fn plan(&self, sql: &str, py: Python) -> PyResult<PyExecutionGraph> {
         println!("Planning {}", sql);
-        let df = wait_for_future(py, self.ctx.sql(sql))?;
-        let plan = wait_for_future(py, df.create_physical_plan())?;
+        // let df = wait_for_future(py, self.ctx.sql(sql))?;
+        let py_df = self.run_sql(sql, py)?;
+        let py_plan = py_df.call_method0(py, "execution_plan")?;
+        let py_proto = py_plan.call_method0(py, "to_proto")?;
+        let plan_bytes: &[u8] = py_proto.extract(py)?;
+        let plan_node = protobuf::PhysicalPlanNode::decode(plan_bytes).map_err(|e| {
+            PyRuntimeError::new_err(format!(
+                "Unable to decode physical plan protobuf message: {}",
+                e
+            ))
+        })?;
+
+        let codec = DefaultPhysicalExtensionCodec {};
+        let runtime = RuntimeEnv::default();
+        let registry = SessionContext::new();
+        let plan = plan_node.try_into_physical_plan(&registry, &runtime, &codec)?;
 
         let graph = make_execution_graph(plan.clone())?;

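Taken together, the new plan() implementation round-trips the physical plan through the datafusion-python bindings: run_sql() delegates SQL planning to the Python SessionContext, the resulting plan is serialized to protobuf bytes on the Python side, and the Rust side decodes those bytes back into a native ExecutionPlan. A minimal Python-side sketch of that handshake, assuming the execution_plan() and to_proto() methods that the Rust code invokes via pyo3:

# Sketch of the Python calls that PyContext::plan drives through pyo3.
# execution_plan() and to_proto() are assumed from the pyo3 call sites above.
from datafusion import SessionContext

df_ctx = SessionContext()
df_ctx.register_parquet("tips", "examples/tips.parquet")

df = df_ctx.sql("select count(*) from tips")  # what run_sql() invokes
plan = df.execution_plan()                    # physical plan object
plan_bytes = plan.to_proto()                  # protobuf-encoded plan bytes
# On the Rust side these bytes are decoded with protobuf::PhysicalPlanNode::decode
# and rebuilt into an ExecutionPlan via try_into_physical_plan.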