Skip to content

Commit f0d25a2

Browse files
committed
Remove pyarrow dep from datafusion. Add in PyScalarValue wrapper and rename DataFusionError to PyDataFusionError to be less confusing
1 parent 9650a82 commit f0d25a2

25 files changed

+520
-186
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync
3838
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
3939
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
4040
arrow = { version = "53", features = ["pyarrow"] }
41-
datafusion = { version = "44.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
41+
datafusion = { version = "44.0.0", features = ["avro", "unicode_expressions"] }
4242
datafusion-substrait = { version = "44.0.0", optional = true }
4343
datafusion-proto = { version = "44.0.0" }
4444
datafusion-ffi = { version = "44.0.0" }

python/tests/test_indexing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def test_err(df):
4444
with pytest.raises(Exception) as e_info:
4545
df["c"]
4646

47-
assert "Schema error: No field named c." in e_info.value.args[0]
47+
for e in ["SchemaError", "FieldNotFound", 'name: "c"']:
48+
assert e in e_info.value.args[0]
4849

4950
with pytest.raises(Exception) as e_info:
5051
df[1]

src/catalog.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::sync::Arc;
2121
use pyo3::exceptions::PyKeyError;
2222
use pyo3::prelude::*;
2323

24-
use crate::errors::DataFusionError;
24+
use crate::errors::PyDataFusionError;
2525
use crate::utils::wait_for_future;
2626
use datafusion::{
2727
arrow::pyarrow::ToPyArrow,
@@ -97,10 +97,12 @@ impl PyDatabase {
9797
}
9898

9999
fn table(&self, name: &str, py: Python) -> PyResult<PyTable> {
100-
if let Some(table) = wait_for_future(py, self.database.table(name))? {
100+
if let Some(table) =
101+
wait_for_future(py, self.database.table(name)).map_err(PyDataFusionError::from)?
102+
{
101103
Ok(PyTable::new(table))
102104
} else {
103-
Err(DataFusionError::Common(format!("Table not found: {name}")).into())
105+
Err(PyDataFusionError::Common(format!("Table not found: {name}")).into())
104106
}
105107
}
106108

src/common/data_type.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ use pyo3::{exceptions::PyValueError, prelude::*};
2323

2424
use crate::errors::py_datafusion_err;
2525

26+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
27+
pub struct PyScalarValue(pub ScalarValue);
28+
29+
impl From<ScalarValue> for PyScalarValue {
30+
fn from(value: ScalarValue) -> Self {
31+
Self(value)
32+
}
33+
}
34+
impl From<PyScalarValue> for ScalarValue {
35+
fn from(value: PyScalarValue) -> Self {
36+
value.0
37+
}
38+
}
39+
2640
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
2741
#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")]
2842
pub enum RexType {

src/config.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use pyo3::types::*;
2121
use datafusion::common::ScalarValue;
2222
use datafusion::config::ConfigOptions;
2323

24+
use crate::errors::PyDataFusionError;
25+
2426
#[pyclass(name = "Config", module = "datafusion", subclass)]
2527
#[derive(Clone)]
2628
pub(crate) struct PyConfig {
@@ -40,7 +42,7 @@ impl PyConfig {
4042
#[staticmethod]
4143
pub fn from_env() -> PyResult<Self> {
4244
Ok(Self {
43-
config: ConfigOptions::from_env()?,
45+
config: ConfigOptions::from_env().map_err(PyDataFusionError::from)?,
4446
})
4547
}
4648

@@ -60,7 +62,8 @@ impl PyConfig {
6062
let scalar_value = py_obj_to_scalar_value(py, value);
6163
self.config
6264
.set(key, scalar_value.to_string().as_str())
63-
.map_err(|e| e.into())
65+
.map_err(PyDataFusionError::from)
66+
.map_err(PyErr::from)
6467
}
6568

6669
/// Get all configuration options

src/context.rs

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, DataFusionError};
37+
use crate::errors::{py_datafusion_err, PyDataFusionError};
3838
use crate::expr::sort_expr::PySortExpr;
3939
use crate::physical_plan::PyExecutionPlan;
4040
use crate::record_batch::PyRecordBatchStream;
@@ -288,7 +288,11 @@ impl PySessionContext {
288288
} else {
289289
RuntimeEnvBuilder::default()
290290
};
291-
let runtime = Arc::new(runtime_env_builder.build()?);
291+
let runtime = Arc::new(
292+
runtime_env_builder
293+
.build()
294+
.map_err(PyDataFusionError::from)?,
295+
);
292296
let session_state = SessionStateBuilder::new()
293297
.with_config(config)
294298
.with_runtime_env(runtime)
@@ -359,19 +363,19 @@ impl PySessionContext {
359363
.map(|e| e.into_iter().map(|f| f.into()).collect())
360364
.collect(),
361365
);
362-
let table_path = ListingTableUrl::parse(path)?;
366+
let table_path = ListingTableUrl::parse(path).map_err(PyDataFusionError::from)?;
363367
let resolved_schema: SchemaRef = match schema {
364368
Some(s) => Arc::new(s.0),
365369
None => {
366370
let state = self.ctx.state();
367371
let schema = options.infer_schema(&state, &table_path);
368-
wait_for_future(py, schema).map_err(DataFusionError::from)?
372+
wait_for_future(py, schema).map_err(PyDataFusionError::from)?
369373
}
370374
};
371375
let config = ListingTableConfig::new(table_path)
372376
.with_listing_options(options)
373377
.with_schema(resolved_schema);
374-
let table = ListingTable::try_new(config)?;
378+
let table = ListingTable::try_new(config).map_err(PyDataFusionError::from)?;
375379
self.register_table(
376380
name,
377381
&PyTable {
@@ -384,7 +388,7 @@ impl PySessionContext {
384388
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
385389
pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
386390
let result = self.ctx.sql(query);
387-
let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
391+
let df = wait_for_future(py, result).map_err(PyDataFusionError::from)?;
388392
Ok(PyDataFrame::new(df))
389393
}
390394

@@ -401,7 +405,7 @@ impl PySessionContext {
401405
SQLOptions::new()
402406
};
403407
let result = self.ctx.sql_with_options(query, options);
404-
let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
408+
let df = wait_for_future(py, result).map_err(PyDataFusionError::from)?;
405409
Ok(PyDataFrame::new(df))
406410
}
407411

@@ -419,7 +423,7 @@ impl PySessionContext {
419423
partitions.0[0][0].schema()
420424
};
421425

422-
let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
426+
let table = MemTable::try_new(schema, partitions.0).map_err(PyDataFusionError::from)?;
423427

424428
// generate a random (unique) name for this table if none is provided
425429
// table name cannot start with numeric digit
@@ -435,9 +439,10 @@ impl PySessionContext {
435439

436440
self.ctx
437441
.register_table(&*table_name, Arc::new(table))
438-
.map_err(DataFusionError::from)?;
442+
.map_err(PyDataFusionError::from)?;
439443

440-
let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;
444+
let table =
445+
wait_for_future(py, self._table(&table_name)).map_err(PyDataFusionError::from)?;
441446

442447
let df = PyDataFrame::new(table);
443448
Ok(df)
@@ -503,7 +508,7 @@ impl PySessionContext {
503508
let schema = stream_reader.schema().as_ref().to_owned();
504509
let batches = stream_reader
505510
.collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()
506-
.map_err(DataFusionError::from)?;
511+
.map_err(PyDataFusionError::from)?;
507512

508513
(schema, batches)
509514
} else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
@@ -562,14 +567,14 @@ impl PySessionContext {
562567
pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
563568
self.ctx
564569
.register_table(name, table.table())
565-
.map_err(DataFusionError::from)?;
570+
.map_err(PyDataFusionError::from)?;
566571
Ok(())
567572
}
568573

569574
pub fn deregister_table(&mut self, name: &str) -> PyResult<()> {
570575
self.ctx
571576
.deregister_table(name)
572-
.map_err(DataFusionError::from)?;
577+
.map_err(PyDataFusionError::from)?;
573578
Ok(())
574579
}
575580

@@ -587,7 +592,10 @@ impl PySessionContext {
587592
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
588593
let provider: ForeignTableProvider = provider.into();
589594

590-
let _ = self.ctx.register_table(name, Arc::new(provider))?;
595+
let _ = self
596+
.ctx
597+
.register_table(name, Arc::new(provider))
598+
.map_err(PyDataFusionError::from)?;
591599

592600
Ok(())
593601
} else {
@@ -603,10 +611,10 @@ impl PySessionContext {
603611
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
604612
) -> PyResult<()> {
605613
let schema = partitions.0[0][0].schema();
606-
let table = MemTable::try_new(schema, partitions.0)?;
614+
let table = MemTable::try_new(schema, partitions.0).map_err(PyDataFusionError::from)?;
607615
self.ctx
608616
.register_table(name, Arc::new(table))
609-
.map_err(DataFusionError::from)?;
617+
.map_err(PyDataFusionError::from)?;
610618
Ok(())
611619
}
612620

@@ -642,7 +650,7 @@ impl PySessionContext {
642650
.collect();
643651

644652
let result = self.ctx.register_parquet(name, path, options);
645-
wait_for_future(py, result).map_err(DataFusionError::from)?;
653+
wait_for_future(py, result).map_err(PyDataFusionError::from)?;
646654
Ok(())
647655
}
648656

@@ -685,11 +693,11 @@ impl PySessionContext {
685693
if path.is_instance_of::<PyList>() {
686694
let paths = path.extract::<Vec<String>>()?;
687695
let result = self.register_csv_from_multiple_paths(name, paths, options);
688-
wait_for_future(py, result).map_err(DataFusionError::from)?;
696+
wait_for_future(py, result).map_err(PyDataFusionError::from)?;
689697
} else {
690698
let path = path.extract::<String>()?;
691699
let result = self.ctx.register_csv(name, &path, options);
692-
wait_for_future(py, result).map_err(DataFusionError::from)?;
700+
wait_for_future(py, result).map_err(PyDataFusionError::from)?;
693701
}
694702

695703
Ok(())
@@ -726,7 +734,7 @@ impl PySessionContext {
726734
options.schema = schema.as_ref().map(|x| &x.0);
727735

728736
let result = self.ctx.register_json(name, path, options);
729-
wait_for_future(py, result).map_err(DataFusionError::from)?;
737+
wait_for_future(py, result).map_err(PyDataFusionError::from)?;
730738

731739
Ok(())
732740
}
@@ -756,7 +764,7 @@ impl PySessionContext {
756764
options.schema = schema.as_ref().map(|x| &x.0);
757765

758766
let result = self.ctx.register_avro(name, path, options);
759-
wait_for_future(py, result).map_err(DataFusionError::from)?;
767+
wait_for_future(py, result).map_err(PyDataFusionError::from)?;
760768

761769
Ok(())
762770
}
@@ -772,7 +780,7 @@ impl PySessionContext {
772780

773781
self.ctx
774782
.register_table(name, table)
775-
.map_err(DataFusionError::from)?;
783+
.map_err(PyDataFusionError::from)?;
776784

777785
Ok(())
778786
}
@@ -825,11 +833,16 @@ impl PySessionContext {
825833
}
826834

827835
pub fn table_exist(&self, name: &str) -> PyResult<bool> {
828-
Ok(self.ctx.table_exist(name)?)
836+
Ok(self
837+
.ctx
838+
.table_exist(name)
839+
.map_err(PyDataFusionError::from)?)
829840
}
830841

831842
pub fn empty_table(&self) -> PyResult<PyDataFrame> {
832-
Ok(PyDataFrame::new(self.ctx.read_empty()?))
843+
Ok(PyDataFrame::new(
844+
self.ctx.read_empty().map_err(PyDataFusionError::from)?,
845+
))
833846
}
834847

835848
pub fn session_id(&self) -> String {
@@ -859,10 +872,10 @@ impl PySessionContext {
859872
let df = if let Some(schema) = schema {
860873
options.schema = Some(&schema.0);
861874
let result = self.ctx.read_json(path, options);
862-
wait_for_future(py, result).map_err(DataFusionError::from)?
875+
wait_for_future(py, result).map_err(PyDataFusionError::from)?
863876
} else {
864877
let result = self.ctx.read_json(path, options);
865-
wait_for_future(py, result).map_err(DataFusionError::from)?
878+
wait_for_future(py, result).map_err(PyDataFusionError::from)?
866879
};
867880
Ok(PyDataFrame::new(df))
868881
}
@@ -909,12 +922,14 @@ impl PySessionContext {
909922
let paths = path.extract::<Vec<String>>()?;
910923
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
911924
let result = self.ctx.read_csv(paths, options);
912-
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
925+
let df =
926+
PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
913927
Ok(df)
914928
} else {
915929
let path = path.extract::<String>()?;
916930
let result = self.ctx.read_csv(path, options);
917-
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
931+
let df =
932+
PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
918933
Ok(df)
919934
}
920935
}
@@ -952,7 +967,7 @@ impl PySessionContext {
952967
.collect();
953968

954969
let result = self.ctx.read_parquet(path, options);
955-
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
970+
let df = PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
956971
Ok(df)
957972
}
958973

@@ -972,10 +987,10 @@ impl PySessionContext {
972987
let df = if let Some(schema) = schema {
973988
options.schema = Some(&schema.0);
974989
let read_future = self.ctx.read_avro(path, options);
975-
wait_for_future(py, read_future).map_err(DataFusionError::from)?
990+
wait_for_future(py, read_future).map_err(PyDataFusionError::from)?
976991
} else {
977992
let read_future = self.ctx.read_avro(path, options);
978-
wait_for_future(py, read_future).map_err(DataFusionError::from)?
993+
wait_for_future(py, read_future).map_err(PyDataFusionError::from)?
979994
};
980995
Ok(PyDataFrame::new(df))
981996
}
@@ -984,7 +999,7 @@ impl PySessionContext {
984999
let df = self
9851000
.ctx
9861001
.read_table(table.table())
987-
.map_err(DataFusionError::from)?;
1002+
.map_err(PyDataFusionError::from)?;
9881003
Ok(PyDataFrame::new(df))
9891004
}
9901005

@@ -1019,7 +1034,9 @@ impl PySessionContext {
10191034
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
10201035
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
10211036
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
1022-
Ok(PyRecordBatchStream::new(stream?))
1037+
Ok(PyRecordBatchStream::new(
1038+
stream.map_err(PyDataFusionError::from)?,
1039+
))
10231040
}
10241041
}
10251042

@@ -1071,13 +1088,13 @@ impl PySessionContext {
10711088

10721089
pub fn convert_table_partition_cols(
10731090
table_partition_cols: Vec<(String, String)>,
1074-
) -> Result<Vec<(String, DataType)>, DataFusionError> {
1091+
) -> Result<Vec<(String, DataType)>, PyDataFusionError> {
10751092
table_partition_cols
10761093
.into_iter()
10771094
.map(|(name, ty)| match ty.as_str() {
10781095
"string" => Ok((name, DataType::Utf8)),
10791096
"int" => Ok((name, DataType::Int32)),
1080-
_ => Err(DataFusionError::Common(format!(
1097+
_ => Err(PyDataFusionError::Common(format!(
10811098
"Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
10821099
))),
10831100
})

0 commit comments

Comments
 (0)