Skip to content

Commit 2665166

Browse files
committed
Remove validate_pycapsule
The `Bound<'_, PyCapsule>::pointer_checked` method performs the same validation and is already used across the codebase.
1 parent 21990b0 commit 2665166

File tree

9 files changed

+28
-67
lines changed

9 files changed

+28
-67
lines changed

crates/core/src/array.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ use arrow::array::{Array, ArrayRef};
2222
use arrow::datatypes::{Field, FieldRef};
2323
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
2424
use arrow::pyarrow::ToPyArrow;
25-
use datafusion_python_util::validate_pycapsule;
26-
use pyo3::ffi::c_str;
2725
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
2826
use pyo3::types::PyCapsule;
2927
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};
@@ -53,10 +51,8 @@ impl PyArrowArrayExportable {
5351
requested_schema: Option<Bound<'py, PyCapsule>>,
5452
) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
5553
let field = if let Some(schema_capsule) = requested_schema {
56-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
57-
5854
let data: NonNull<FFI_ArrowSchema> = schema_capsule
59-
.pointer_checked(Some(c_str!("arrow_schema")))?
55+
.pointer_checked(Some(c"arrow_schema"))?
6056
.cast();
6157
let schema_ptr = unsafe { data.as_ref() };
6258
let desired_field = Field::try_from(schema_ptr)?;

crates/core/src/catalog.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
3131
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
3232
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
3333
use datafusion_python_util::{
34-
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, validate_pycapsule,
35-
wait_for_future,
34+
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, wait_for_future,
3635
};
3736
use pyo3::IntoPyObjectExt;
3837
use pyo3::exceptions::PyKeyError;
39-
use pyo3::ffi::c_str;
4038
use pyo3::prelude::*;
4139
use pyo3::types::PyCapsule;
4240

@@ -659,9 +657,8 @@ fn extract_catalog_provider_from_pyobj(
659657
}
660658

661659
let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
662-
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
663660
let data: NonNull<FFI_CatalogProvider> = capsule
664-
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
661+
.pointer_checked(Some(c"datafusion_catalog_provider"))?
665662
.cast();
666663
let provider = unsafe { data.as_ref() };
667664
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
@@ -692,10 +689,8 @@ fn extract_schema_provider_from_pyobj(
692689
}
693690

694691
let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
695-
validate_pycapsule(capsule, "datafusion_schema_provider")?;
696-
697692
let data: NonNull<FFI_SchemaProvider> = capsule
698-
.pointer_checked(Some(c_str!("datafusion_schema_provider")))?
693+
.pointer_checked(Some(c"datafusion_schema_provider"))?
699694
.cast();
700695
let provider = unsafe { data.as_ref() };
701696
let provider: Arc<dyn SchemaProvider + Send> = provider.into();

crates/core/src/context.rs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
5555
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
5656
use datafusion_python_util::{
5757
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
58-
get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
58+
get_tokio_runtime, spawn_future, wait_for_future,
5959
};
6060
use object_store::ObjectStore;
6161
use pyo3::IntoPyObjectExt;
6262
use pyo3::exceptions::{PyKeyError, PyValueError};
63-
use pyo3::ffi::c_str;
6463
use pyo3::prelude::*;
6564
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
6665
use url::Url;
@@ -675,10 +674,8 @@ impl PySessionContext {
675674

676675
let factory: Arc<dyn TableProviderFactory> =
677676
if let Ok(capsule) = factory.cast::<PyCapsule>().map_err(py_datafusion_err) {
678-
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;
679-
680677
let data: NonNull<FFI_TableProviderFactory> = capsule
681-
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
678+
.pointer_checked(Some(c"datafusion_table_provider_factory"))?
682679
.cast();
683680
let factory = unsafe { data.as_ref() };
684681
factory.into()
@@ -709,12 +706,9 @@ impl PySessionContext {
709706
.call1((codec_capsule,))?;
710707
}
711708

712-
let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
713-
{
714-
validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
715-
709+
let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
716710
let data: NonNull<FFI_CatalogProviderList> = capsule
717-
.pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))?
711+
.pointer_checked(Some(c"datafusion_catalog_provider_list"))?
718712
.cast();
719713
let provider = unsafe { data.as_ref() };
720714
let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
@@ -747,12 +741,9 @@ impl PySessionContext {
747741
.call1((codec_capsule,))?;
748742
}
749743

750-
let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
751-
{
752-
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
753-
744+
let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
754745
let data: NonNull<FFI_CatalogProvider> = capsule
755-
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
746+
.pointer_checked(Some(c"datafusion_catalog_provider"))?
756747
.cast();
757748
let provider = unsafe { data.as_ref() };
758749
let provider: Arc<dyn CatalogProvider + Send> = provider.into();

crates/core/src/dataframe.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,11 @@ use datafusion::logical_expr::SortExpr;
4141
use datafusion::logical_expr::dml::InsertOp;
4242
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
4343
use datafusion::prelude::*;
44-
use datafusion_python_util::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
44+
use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future};
4545
use futures::{StreamExt, TryStreamExt};
4646
use parking_lot::Mutex;
4747
use pyo3::PyErr;
4848
use pyo3::exceptions::PyValueError;
49-
use pyo3::ffi::c_str;
5049
use pyo3::prelude::*;
5150
use pyo3::pybacked::PyBackedStr;
5251
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -1117,10 +1116,8 @@ impl PyDataFrame {
11171116
let mut projection: Option<SchemaRef> = None;
11181117

11191118
if let Some(schema_capsule) = requested_schema {
1120-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
1121-
11221119
let data: NonNull<FFI_ArrowSchema> = schema_capsule
1123-
.pointer_checked(Some(c_str!("arrow_schema")))?
1120+
.pointer_checked(Some(c"arrow_schema"))?
11241121
.cast();
11251122
let schema_ptr = unsafe { data.as_ref() };
11261123
let desired_schema = Schema::try_from(schema_ptr)?;

crates/core/src/udaf.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ use datafusion::logical_expr::{
2727
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
2828
};
2929
use datafusion_ffi::udaf::FFI_AggregateUDF;
30-
use datafusion_python_util::{parse_volatility, validate_pycapsule};
31-
use pyo3::ffi::c_str;
30+
use datafusion_python_util::parse_volatility;
3231
use pyo3::prelude::*;
3332
use pyo3::types::{PyCapsule, PyTuple};
3433

@@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
157156
}
158157

159158
fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
160-
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
161-
162159
let data: NonNull<FFI_AggregateUDF> = capsule
163-
.pointer_checked(Some(c_str!("datafusion_aggregate_udf")))?
160+
.pointer_checked(Some(c"datafusion_aggregate_udf"))?
164161
.cast();
165162
let udaf = unsafe { data.as_ref() };
166163
let udaf: Arc<dyn AggregateUDFImpl> = udaf.into();

crates/core/src/udf.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,12 @@ use datafusion::logical_expr::{
3232
Volatility,
3333
};
3434
use datafusion_ffi::udf::FFI_ScalarUDF;
35-
use datafusion_python_util::{parse_volatility, validate_pycapsule};
36-
use pyo3::ffi::c_str;
35+
use datafusion_python_util::parse_volatility;
3736
use pyo3::prelude::*;
3837
use pyo3::types::{PyCapsule, PyTuple};
3938

4039
use crate::array::PyArrowArrayExportable;
41-
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
40+
use crate::errors::{PyDataFusionResult, to_datafusion_err};
4241
use crate::expr::PyExpr;
4342

4443
/// This struct holds the Python written function that is a
@@ -194,11 +193,9 @@ impl PyScalarUDF {
194193
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
195194
if func.hasattr("__datafusion_scalar_udf__")? {
196195
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
197-
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
198-
validate_pycapsule(capsule, "datafusion_scalar_udf")?;
199-
196+
let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
200197
let data: NonNull<FFI_ScalarUDF> = capsule
201-
.pointer_checked(Some(c_str!("datafusion_scalar_udf")))?
198+
.pointer_checked(Some(c"datafusion_scalar_udf"))?
202199
.cast();
203200
let udf = unsafe { data.as_ref() };
204201
let udf: Arc<dyn ScalarUDFImpl> = udf.into();

crates/core/src/udtf.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ use datafusion::catalog::{TableFunctionImpl, TableProvider};
2222
use datafusion::error::Result as DataFusionResult;
2323
use datafusion::logical_expr::Expr;
2424
use datafusion_ffi::udtf::FFI_TableFunction;
25-
use datafusion_python_util::validate_pycapsule;
2625
use pyo3::IntoPyObjectExt;
2726
use pyo3::exceptions::{PyImportError, PyTypeError};
28-
use pyo3::ffi::c_str;
2927
use pyo3::prelude::*;
3028
use pyo3::types::{PyCapsule, PyTuple, PyType};
3129

@@ -73,11 +71,9 @@ impl PyTableFunction {
7371
err
7472
}
7573
})?;
76-
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
77-
validate_pycapsule(capsule, "datafusion_table_function")?;
78-
74+
let capsule = capsule.cast::<PyCapsule>()?;
7975
let data: NonNull<FFI_TableFunction> = capsule
80-
.pointer_checked(Some(c_str!("datafusion_table_function")))?
76+
.pointer_checked(Some(c"datafusion_table_function"))?
8177
.cast();
8278
let ffi_func = unsafe { data.as_ref() };
8379
let foreign_func: Arc<dyn TableFunctionImpl> = ffi_func.to_owned().into();

crates/core/src/udwf.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@ use datafusion::logical_expr::{
3232
};
3333
use datafusion::scalar::ScalarValue;
3434
use datafusion_ffi::udwf::FFI_WindowUDF;
35-
use datafusion_python_util::{parse_volatility, validate_pycapsule};
35+
use datafusion_python_util::parse_volatility;
3636
use pyo3::exceptions::PyValueError;
37-
use pyo3::ffi::c_str;
3837
use pyo3::prelude::*;
3938
use pyo3::types::{PyCapsule, PyList, PyTuple};
4039

4140
use crate::common::data_type::PyScalarValue;
42-
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
41+
use crate::errors::{PyDataFusionResult, to_datafusion_err};
4342
use crate::expr::PyExpr;
4443

4544
#[derive(Debug)]
@@ -262,11 +261,9 @@ impl PyWindowUDF {
262261
func
263262
};
264263

265-
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
266-
validate_pycapsule(capsule, "datafusion_window_udf")?;
267-
264+
let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
268265
let data: NonNull<FFI_WindowUDF> = capsule
269-
.pointer_checked(Some(c_str!("datafusion_window_udf")))?
266+
.pointer_checked(Some(c"datafusion_window_udf"))?
270267
.cast();
271268
let udwf = unsafe { data.as_ref() };
272269
let udwf: Arc<dyn WindowUDFImpl> = udwf.into();

crates/util/src/lib.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@ use datafusion::logical_expr::Volatility;
2626
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
2727
use datafusion_ffi::table_provider::FFI_TableProvider;
2828
use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError};
29-
use pyo3::ffi::c_str;
3029
use pyo3::prelude::*;
3130
use pyo3::types::{PyCapsule, PyType};
3231
use tokio::runtime::Runtime;
3332
use tokio::task::JoinHandle;
3433
use tokio::time::sleep;
3534

36-
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
35+
use crate::errors::{PyDataFusionError, PyDataFusionResult, to_datafusion_err};
3736

3837
pub mod errors;
3938

@@ -186,11 +185,9 @@ pub fn table_provider_from_pycapsule<'py>(
186185
})?;
187186
}
188187

189-
if let Ok(capsule) = obj.cast::<PyCapsule>().map_err(py_datafusion_err) {
190-
validate_pycapsule(capsule, "datafusion_table_provider")?;
191-
188+
if let Ok(capsule) = obj.cast::<PyCapsule>() {
192189
let data: NonNull<FFI_TableProvider> = capsule
193-
.pointer_checked(Some(c_str!("datafusion_table_provider")))?
190+
.pointer_checked(Some(c"datafusion_table_provider"))?
194191
.cast();
195192
let provider = unsafe { data.as_ref() };
196193
let provider: Arc<dyn TableProvider> = provider.into();
@@ -220,10 +217,8 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound<PyAny>) -> PyResult<FFI_Logic
220217
};
221218

222219
let capsule = capsule.cast::<PyCapsule>()?;
223-
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
224-
225220
let data: NonNull<FFI_LogicalExtensionCodec> = capsule
226-
.pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?
221+
.pointer_checked(Some(c"datafusion_logical_extension_codec"))?
227222
.cast();
228223
let codec = unsafe { data.as_ref() };
229224

0 commit comments

Comments
 (0)