Commit 3990555

Removing the session context from the Ray context and testing by running a SQL query

1 parent: d133a94
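
The user-facing effect of this change: DatafusionRayContext no longer builds its own DataFusion SessionContext from a target partition count. Instead, the caller constructs a SessionContext (from datafusion-python), registers tables on it directly, and passes it in. A minimal usage sketch of the new constructor, distilled from the updated examples/tips.py below (the parquet path is assumed for illustration):

import ray
from datafusion import SessionContext
from datafusion_ray import DatafusionRayContext

ray.init(resources={"worker": 1})

# The caller now owns the DataFusion context and registers tables on it.
df_ctx = SessionContext()
df_ctx.register_parquet("tips", "tips.parquet")  # path assumed; see examples/tips.py

# The Ray context wraps the existing SessionContext instead of creating one.
ray_ctx = DatafusionRayContext(df_ctx, use_ray_shuffle=True)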

File tree: 3 files changed, +86 −128 lines


Cargo.lock

Lines changed: 18 additions & 74 deletions
Some generated files are not rendered by default.

examples/tips.py

Lines changed: 6 additions & 4 deletions
@@ -16,9 +16,9 @@
 # under the License.
 
 import os
-import pandas as pd
 import ray
 
+from datafusion import SessionContext
 from datafusion_ray import DatafusionRayContext
 
 SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -27,12 +27,14 @@
 ray.init(resources={"worker": 1})
 
 # Create a context and register a table
-ctx = DatafusionRayContext(2, use_ray_shuffle=True)
+df_ctx = SessionContext()
+
+ray_ctx = DatafusionRayContext(df_ctx, use_ray_shuffle=True)
 # Register either a CSV or Parquet file
 # ctx.register_csv("tips", f"{SCRIPT_DIR}/tips.csv", True)
-ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")
+df_ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")
 
-result_set = ctx.sql(
+result_set = ray_ctx.sql(
     "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker"
 )
 for record_batch in result_set:

src/context.rs

Lines changed: 62 additions & 50 deletions
@@ -27,14 +27,20 @@ use datafusion::execution::context::TaskContext;
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::FairSpillPool;
 use datafusion::execution::options::ReadOptions;
+use datafusion::execution::registry::MemoryFunctionRegistry;
 use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{displayable, ExecutionPlan};
 use datafusion::prelude::*;
 use datafusion_proto::bytes::{
     physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
 };
+use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
+use datafusion_proto::protobuf;
 use datafusion_python::physical_plan::PyExecutionPlan;
 use futures::StreamExt;
+use prost::{DecodeError, Message};
+use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
 use pyo3::types::{PyList, PyLong, PyTuple};
 use std::collections::HashMap;
@@ -46,78 +52,84 @@ type PyResultSet = Vec<PyObject>;
 
 #[pyclass(name = "Context", module = "datafusion_ray", subclass)]
 pub struct PyContext {
-    pub(crate) ctx: SessionContext,
+    pub(crate) py_ctx: PyObject,
     use_ray_shuffle: bool,
 }
 
 #[pymethods]
 impl PyContext {
     #[new]
-    pub fn new(target_partitions: usize, use_ray_shuffle: bool) -> Result<Self> {
-        let config = SessionConfig::default()
-            .with_target_partitions(target_partitions)
-            .with_batch_size(16 * 1024)
-            .with_repartition_aggregations(true)
-            .with_repartition_windows(true)
-            .with_repartition_joins(true)
-            .with_parquet_pruning(true);
-
-        let mem_pool_size = 1024 * 1024 * 1024;
-        let runtime_config = datafusion::execution::runtime_env::RuntimeConfig::new()
-            .with_memory_pool(Arc::new(FairSpillPool::new(mem_pool_size)))
-            .with_disk_manager(DiskManagerConfig::new_specified(vec!["/tmp".into()]));
-        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-        let ctx = SessionContext::with_config_rt(config, runtime);
+    pub fn new(session_ctx: PyObject, use_ray_shuffle: bool) -> Result<Self> {
         Ok(Self {
-            ctx,
+            py_ctx: session_ctx,
             use_ray_shuffle,
         })
     }
 
-    pub fn register_csv(
-        &self,
-        name: &str,
-        path: &str,
-        has_header: bool,
-        py: Python,
-    ) -> PyResult<()> {
-        let options = CsvReadOptions::default().has_header(has_header);
-        wait_for_future(py, self.ctx.register_csv(name, path, options))?;
-        Ok(())
-    }
+    // pub fn register_csv(
+    //     &self,
+    //     name: &str,
+    //     path: &str,
+    //     has_header: bool,
+    //     py: Python,
+    // ) -> PyResult<()> {
+    //     let options = CsvReadOptions::default().has_header(has_header);
+    //     wait_for_future(py, self.ctx.register_csv(name, path, options))?;
+    //     Ok(())
+    // }
 
-    pub fn register_parquet(&self, name: &str, path: &str, py: Python) -> PyResult<()> {
-        let options = ParquetReadOptions::default();
-        wait_for_future(py, self.ctx.register_parquet(name, path, options))?;
-        Ok(())
-    }
+    // pub fn register_parquet(&self, name: &str, path: &str, py: Python) -> PyResult<()> {
+    //     let options = ParquetReadOptions::default();
+    //     wait_for_future(py, self.ctx.register_parquet(name, path, options))?;
+    //     Ok(())
+    // }
 
-    pub fn register_datalake_table(
-        &self,
-        name: &str,
-        path: Vec<String>,
-        py: Python,
-    ) -> PyResult<()> {
-        // let options = ParquetReadOptions::default();
-        // let listing_options = options.to_listing_options(&self.ctx.state().config());
-        // wait_for_future(py, self.ctx.register_listing_table(name, path, listing_options, None, None))?;
-        // Ok(())
-        unimplemented!()
-    }
+    // pub fn register_datalake_table(
+    //     &self,
+    //     name: &str,
+    //     path: Vec<String>,
+    //     py: Python,
+    // ) -> PyResult<()> {
+    //     // let options = ParquetReadOptions::default();
+    //     // let listing_options = options.to_listing_options(&self.ctx.state().config());
+    //     // wait_for_future(py, self.ctx.register_listing_table(name, path, listing_options, None, None))?;
+    //     // Ok(())
+    //     unimplemented!()
+    // }
 
     /// Execute SQL directly against the DataFusion context. Useful for statements
     /// such as "create view" or "drop view"
-    pub fn sql(&self, sql: &str, py: Python) -> PyResult<()> {
-        println!("Executing {}", sql);
-        let _df = wait_for_future(py, self.ctx.sql(sql))?;
+    pub fn sql(&self, query: &str, py: Python) -> PyResult<()> {
+        println!("Executing {}", query);
+        // let _df = wait_for_future(py, self.ctx.sql(sql))?;
+        let _df = self.run_sql(query, py);
         Ok(())
     }
 
+    fn run_sql(&self, query: &str, py: Python) -> PyResult<Py<PyAny>> {
+        let args = PyTuple::new_bound(py, [query]);
+        self.py_ctx.call_method1(py, "sql", args)
+    }
+
     /// Plan a distributed SELECT query for executing against the Ray workers
     pub fn plan(&self, sql: &str, py: Python) -> PyResult<PyExecutionGraph> {
         println!("Planning {}", sql);
-        let df = wait_for_future(py, self.ctx.sql(sql))?;
-        let plan = wait_for_future(py, df.create_physical_plan())?;
+        // let df = wait_for_future(py, self.ctx.sql(sql))?;
+        let py_df = self.run_sql(sql, py)?;
+        let py_plan = py_df.call_method0(py, "execution_plan")?;
+        let py_proto = py_plan.call_method0(py, "to_proto")?;
+        let plan_bytes: &[u8] = py_proto.extract(py)?;
+        let plan_node = protobuf::PhysicalPlanNode::decode(plan_bytes).map_err(|e| {
+            PyRuntimeError::new_err(format!(
+                "Unable to decode physical plan protobuf message: {}",
+                e
+            ))
+        })?;
+
+        let codec = DefaultPhysicalExtensionCodec {};
+        let runtime = RuntimeEnv::default();
+        let registry = SessionContext::new();
+        let plan = plan_node.try_into_physical_plan(&registry, &runtime, &codec)?;
 
         let graph = make_execution_graph(plan.clone(), self.use_ray_shuffle)?;
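Note how the new plan() implementation round-trips the physical plan through Python rather than planning natively: it calls the wrapped SessionContext's sql() via pyo3, asks the resulting DataFrame for its execution plan, serializes that plan to protobuf bytes, and decodes those bytes back into a DataFusion ExecutionPlan on the Rust side. A sketch of the Python half of that round-trip, assuming datafusion-python exposes the execution_plan() and to_proto() methods that the Rust code invokes through call_method0:

from datafusion import SessionContext

ctx = SessionContext()
ctx.register_parquet("tips", "tips.parquet")  # hypothetical table, for illustration

df = ctx.sql("select count(*) from tips")
plan = df.execution_plan()    # what plan() obtains via call_method0(py, "execution_plan")
plan_bytes = plan.to_proto()  # protobuf-encoded PhysicalPlanNode bytes, decoded in Rust
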