Skip to content

Commit 9350af4

Browse files
authored
Add copy parameter to functions returning Arrow views (#107)
1 parent 29684ce commit 9350af4

File tree

4 files changed

+87
-33
lines changed

4 files changed

+87
-33
lines changed

python/python/geoindex_rs/rtree.pyi

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,24 @@ IndexLike = Union[np.ndarray, ArrowArrayExportable, Buffer, RTree]
2323
"""A type alias for accepted input as an RTree.
2424
"""
2525

26-
def boxes_at_level(index: IndexLike, level: int) -> Array:
26+
def boxes_at_level(index: IndexLike, level: int, *, copy: bool = False) -> Array:
2727
"""Access the raw bounding box data contained in the RTree at a given tree level.
2828
2929
Args:
3030
index: the RTree to search.
3131
level: The level of the tree to read from. Level 0 is the _base_ of the tree. Each integer higher is one level higher of the tree.
3232
33+
Other Args:
34+
copy: if True, make a _copy_ of the data from the underlying RTree instead of
35+
viewing it directly. Making a copy can be preferred if you'd like to delete
36+
the index itself to save memory.
37+
3338
Returns:
3439
An Arrow FixedSizeListArray containing the bounding box coordinates.
3540
36-
The returned array is a a zero-copy view from Rust. Note that it will keep
37-
the entire index memory alive until the returned array is garbage collected.
41+
If `copy` is `False`, the returned array is a zero-copy view from Rust.
42+
Note that it will keep the entire index memory alive until the returned
43+
array is garbage collected.
3844
"""
3945

4046
def tree_join(
@@ -106,7 +112,7 @@ def neighbors(
106112
An Arrow array with the insertion indexes of query results.
107113
"""
108114

109-
def partitions(index: IndexLike) -> RecordBatch:
115+
def partitions(index: IndexLike, *, copy: bool = False) -> RecordBatch:
110116
"""Extract the spatial partitions from an RTree.
111117
112118
This can be used to find the sorted groups for spatially partitioning the original
@@ -144,15 +150,21 @@ def partitions(index: IndexLike) -> RecordBatch:
144150
Args:
145151
index: the RTree to use.
146152
153+
Other Args:
154+
copy: if True, make a _copy_ of the data from the underlying RTree instead of
155+
viewing it directly. Making a copy can be preferred if you'd like to delete
156+
the index itself to save memory.
157+
147158
Returns:
148159
An Arrow `RecordBatch` with two columns: `indices` and `partition_ids`. `indices` refers to the insertion index of each row and `partition_ids` refers to the partition each row belongs to.
149160
150-
The `indices` column is constructed as a zero-copy view on the provided
151-
index. Therefore, the `indices` array will have type `uint16` if the tree
152-
has fewer than 16,384 items; otherwise it will have type `uint32`.
161+
If `copy` is `False`, the `indices` column is constructed as a zero-copy
162+
view on the provided index. Therefore, the `indices` array will have type
163+
`uint16` if the tree has fewer than 16,384 items; otherwise it will have
164+
type `uint32`.
153165
"""
154166

155-
def partition_boxes(index: IndexLike) -> RecordBatch:
167+
def partition_boxes(index: IndexLike, *, copy: bool = False) -> RecordBatch:
156168
"""Extract the geometries of the spatial partitions from an RTree.
157169
158170
In order for these boxes to be zero-copy from Rust, they are returned as a
@@ -169,12 +181,18 @@ def partition_boxes(index: IndexLike) -> RecordBatch:
169181
Args:
170182
index: the RTree to use.
171183
184+
Other Args:
185+
copy: if True, make a _copy_ of the data from the underlying RTree instead of
186+
viewing it directly. Making a copy can be preferred if you'd like to delete
187+
the index itself to save memory.
188+
172189
Returns:
173190
An Arrow `RecordBatch` with two columns: `boxes` and `partition_ids`. `boxes` stores the box geometry of each partition and `partition_ids` refers to the partition each row belongs to.
174191
175-
The `boxes` column is constructed as a zero-copy view on the internal boxes
176-
data. The `partition_id` column will be `uint16` type if there are less than
177-
65,536 partitions; otherwise it will be `uint32` type.
192+
If `copy` is `False`, the `boxes` column is constructed as a zero-copy view
193+
on the internal boxes data. The `partition_ids` column will be `uint16` type
194+
if there are fewer than 65,536 partitions; otherwise it will be `uint32`
195+
type.
178196
"""
179197

180198
def search(

python/src/rtree/boxes_at_level.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,25 @@ use crate::rtree::input::PyRTreeRef;
88
use crate::util::boxes_to_arrow;
99

1010
#[pyfunction]
11-
pub fn boxes_at_level(py: Python, index: PyRTreeRef, level: usize) -> PyResult<PyObject> {
11+
#[pyo3(signature = (index, level, *, copy = false))]
12+
pub fn boxes_at_level(
13+
py: Python,
14+
index: PyRTreeRef,
15+
level: usize,
16+
copy: bool,
17+
) -> PyResult<PyObject> {
1218
let array = match index {
1319
PyRTreeRef::Float32(tree) => {
1420
let boxes = tree
1521
.boxes_at_level(level)
1622
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
17-
boxes_to_arrow::<Float32Type>(boxes, tree.buffer().clone())
23+
boxes_to_arrow::<Float32Type>(boxes, tree.buffer().clone(), copy)
1824
}
1925
PyRTreeRef::Float64(tree) => {
2026
let boxes = tree
2127
.boxes_at_level(level)
2228
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
23-
boxes_to_arrow::<Float64Type>(boxes, tree.buffer().clone())
29+
boxes_to_arrow::<Float64Type>(boxes, tree.buffer().clone(), copy)
2430
}
2531
};
2632
PyArray::from_array_ref(array).to_arro3(py)

python/src/rtree/partitions.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,26 @@ use crate::rtree::input::PyRTreeRef;
1515
use crate::util::slice_to_arrow;
1616

1717
#[pyfunction]
18-
pub fn partitions(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
18+
#[pyo3(signature = (index, *, copy = false))]
19+
pub fn partitions(py: Python, index: PyRTreeRef, copy: bool) -> PyResult<PyObject> {
1920
let (indices, partition_ids) = match index {
2021
PyRTreeRef::Float32(tree) => {
21-
let indices = indices_to_arrow(tree.indices(), tree.num_items(), tree.buffer().clone());
22+
let indices = indices_to_arrow(
23+
tree.indices(),
24+
tree.num_items(),
25+
tree.buffer().clone(),
26+
copy,
27+
);
2228
let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
2329
(indices, partition_ids)
2430
}
2531
PyRTreeRef::Float64(tree) => {
26-
let indices = indices_to_arrow(tree.indices(), tree.num_items(), tree.buffer().clone());
32+
let indices = indices_to_arrow(
33+
tree.indices(),
34+
tree.num_items(),
35+
tree.buffer().clone(),
36+
copy,
37+
);
2738
let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
2839
(indices, partition_ids)
2940
}
@@ -38,10 +49,19 @@ pub fn partitions(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
3849
.to_arro3(py)
3950
}
4051

41-
fn indices_to_arrow(indices: Indices, num_items: u32, owner: Arc<dyn Allocation>) -> ArrayRef {
52+
fn indices_to_arrow(
53+
indices: Indices,
54+
num_items: u32,
55+
owner: Arc<dyn Allocation>,
56+
copy: bool,
57+
) -> ArrayRef {
4258
match indices {
43-
Indices::U16(slice) => slice_to_arrow::<UInt16Type>(&slice[0..num_items as usize], owner),
44-
Indices::U32(slice) => slice_to_arrow::<UInt32Type>(&slice[0..num_items as usize], owner),
59+
Indices::U16(slice) => {
60+
slice_to_arrow::<UInt16Type>(&slice[0..num_items as usize], owner, copy)
61+
}
62+
Indices::U32(slice) => {
63+
slice_to_arrow::<UInt32Type>(&slice[0..num_items as usize], owner, copy)
64+
}
4565
}
4666
}
4767

@@ -83,8 +103,9 @@ fn partition_id_array(num_items: u32, node_size: u16) -> ArrayRef {
83103
// Since for now we assume that the partition level is the node level, we select the boxes at level
84104
// 1.
85105
#[pyfunction]
86-
pub fn partition_boxes(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
87-
let array = boxes_at_level(py, index, 1)?.extract::<PyArray>(py)?;
106+
#[pyo3(signature = (index, *, copy = false))]
107+
pub fn partition_boxes(py: Python, index: PyRTreeRef, copy: bool) -> PyResult<PyObject> {
108+
let array = boxes_at_level(py, index, 1, copy)?.extract::<PyArray>(py)?;
88109
let (array, _field) = array.into_inner();
89110

90111
let partition_ids: ArrayRef = if array.len() < u16::MAX as _ {

python/src/util.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,34 @@ use arrow_schema::Field;
1010
pub(crate) fn slice_to_arrow<T: ArrowPrimitiveType>(
1111
slice: &[T::Native],
1212
owner: Arc<dyn Allocation>,
13+
copy: bool,
1314
) -> ArrayRef {
14-
let ptr = NonNull::new(slice.as_ptr() as *mut _).unwrap();
15-
let len = slice.len();
16-
let bytes_len = len * T::Native::get_byte_width();
15+
if copy {
16+
Arc::new(PrimitiveArray::<T>::new(
17+
ScalarBuffer::from(slice.to_vec()),
18+
None,
19+
))
20+
} else {
21+
let ptr = NonNull::new(slice.as_ptr() as *mut _).unwrap();
22+
let len = slice.len();
23+
let bytes_len = len * T::Native::get_byte_width();
1724

18-
// Safety:
19-
// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
20-
let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
21-
Arc::new(PrimitiveArray::<T>::new(
22-
ScalarBuffer::new(buffer, 0, len),
23-
None,
24-
))
25+
// Safety:
26+
// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
27+
let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
28+
Arc::new(PrimitiveArray::<T>::new(
29+
ScalarBuffer::new(buffer, 0, len),
30+
None,
31+
))
32+
}
2533
}
2634

2735
pub(crate) fn boxes_to_arrow<T: ArrowPrimitiveType>(
2836
slice: &[T::Native],
2937
owner: Arc<dyn Allocation>,
38+
copy: bool,
3039
) -> ArrayRef {
31-
let values_array = slice_to_arrow::<T>(slice, owner);
40+
let values_array = slice_to_arrow::<T>(slice, owner, copy);
3241
Arc::new(FixedSizeListArray::new(
3342
Arc::new(Field::new("item", values_array.data_type().clone(), false)),
3443
4,

0 commit comments

Comments
 (0)