Skip to content

Commit 0f87a7e

Browse files
authored
feat(python, geoarrow-rust-core): Implement geometry_col (#1381)
to access a geometry column from a RecordBatch or stream
1 parent 6965ff9 commit 0f87a7e

File tree

8 files changed

+247
-72
lines changed

8 files changed

+247
-72
lines changed

python/Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/geoarrow-core/python/geoarrow/rust/core/_rust.pyi

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,9 @@
11
from __future__ import annotations
22

3-
from pathlib import Path
4-
from typing import Any, Sequence, Tuple, overload
3+
from typing import overload
54

6-
from arro3.core import Table
75
from arro3.core.types import ArrowArrayExportable, ArrowStreamExportable
86

9-
try:
10-
import numpy as np
11-
from numpy.typing import NDArray
12-
except ImportError:
13-
pass
14-
15-
try:
16-
import geopandas as gpd
17-
except ImportError:
18-
pass
19-
207
from ._array import GeoArray as GeoArray
218
from ._array_reader import GeoArrayReader as GeoArrayReader
229
from ._chunked_array import GeoChunkedArray as GeoChunkedArray
@@ -51,21 +38,34 @@ from ._interop import to_wkt as to_wkt
5138
from ._operations import get_type_id as get_type_id
5239
from ._scalar import GeoScalar as GeoScalar
5340

54-
# @overload
55-
# def geometry_col(input: ArrowArrayExportable) -> GeoArray: ...
56-
# @overload
57-
# def geometry_col(input: ArrowStreamExportable) -> GeoChunkedArray: ...
58-
# def geometry_col(
59-
# input: ArrowArrayExportable | ArrowStreamExportable,
60-
# ) -> GeoArray | GeoChunkedArray:
61-
# """Access the geometry column of a Table or RecordBatch
62-
63-
# Args:
64-
# input: The Arrow RecordBatch or Table to extract the geometry column from.
65-
66-
# Returns:
67-
# A geometry array or chunked array.
68-
# """
41+
@overload
42+
def geometry_col(
43+
input: ArrowArrayExportable,
44+
*,
45+
name: str | None = None,
46+
) -> GeoArray: ...
47+
@overload
48+
def geometry_col(
49+
input: ArrowStreamExportable,
50+
*,
51+
name: str | None = None,
52+
) -> GeoArrayReader: ...
53+
def geometry_col(
54+
input: ArrowArrayExportable | ArrowStreamExportable,
55+
*,
56+
name: str | None = None,
57+
) -> GeoArray | GeoArrayReader:
58+
"""Access the geometry column of a Table or RecordBatch
59+
60+
Args:
61+
input: The Arrow RecordBatch or Table to extract the geometry column from.
62+
63+
Keyword Args:
64+
name: The name of the geometry column to extract. If not provided, an error will be produced if there are multiple columns with GeoArrow metadata.
65+
66+
Returns:
67+
A geometry array (for RecordBatch input) or a geometry array reader (for stream/Table input).
68+
"""
6969

7070
# Interop
7171

python/geoarrow-core/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
mod constructors;
44
mod interop;
55
mod operations;
6+
mod table;
67

78
use pyo3::exceptions::PyRuntimeWarning;
89
use pyo3::intern;
@@ -118,7 +119,7 @@ fn _rust(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
118119

119120
// Top-level table functions
120121

121-
// m.add_function(wrap_pyfunction!(crate::table::geometry_col, m)?)?;
122+
m.add_function(wrap_pyfunction!(crate::table::geometry_col, m)?)?;
122123

123124
// Interop
124125

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,90 @@
1-
use crate::ffi::to_python::{chunked_native_array_to_pyobject, native_array_to_pyobject};
2-
use crate::interop::util::pytable_to_table;
3-
use geoarrow::array::NativeArrayDyn;
4-
use geoarrow::schema::GeoSchemaExt;
5-
use pyo3::exceptions::PyNotImplementedError;
1+
use arrow_schema::Schema;
2+
use geoarrow_array::{GeoArrowArrayIterator, WrapArray};
3+
use geoarrow_schema::GeoArrowType;
4+
use pyo3::IntoPyObjectExt;
5+
use pyo3::exceptions::{PyIndexError, PyValueError};
66
use pyo3::prelude::*;
7+
use pyo3::pybacked::PyBackedStr;
78
use pyo3_arrow::input::AnyRecordBatch;
8-
use pyo3_geoarrow::PyGeoArrowResult;
9+
use pyo3_geoarrow::input::AnyGeoArray;
10+
use pyo3_geoarrow::{PyGeoArray, PyGeoArrayReader, PyGeoArrowResult};
911

1012
#[pyfunction]
11-
pub fn geometry_col(py: Python, input: AnyRecordBatch) -> PyGeoArrowResult<PyObject> {
12-
match input {
13-
AnyRecordBatch::RecordBatch(rb) => {
14-
let batch = rb.into_inner();
15-
let schema = batch.schema();
13+
#[pyo3(signature = (input, *, name = None))]
14+
pub fn geometry_col<'py>(
15+
py: Python<'py>,
16+
input: Bound<'py, PyAny>,
17+
name: Option<PyBackedStr>,
18+
) -> PyGeoArrowResult<Bound<'py, PyAny>> {
19+
// If the input is already a geometry array or geometry array stream, just return it
20+
if let Ok(input) = input.extract::<AnyGeoArray>() {
21+
match input {
22+
AnyGeoArray::Array(array) => {
23+
return Ok(array.into_bound_py_any(py)?);
24+
}
25+
AnyGeoArray::Stream(stream) => {
26+
return Ok(stream.into_bound_py_any(py)?);
27+
}
28+
}
29+
}
1630

17-
let geom_indices = schema.as_ref().geometry_columns();
18-
let index = if geom_indices.len() == 1 {
19-
geom_indices[0]
20-
} else {
21-
return Err(PyNotImplementedError::new_err(
22-
"Accessing from multiple geometry columns not yet supported",
23-
)
24-
.into());
25-
};
26-
27-
let field = schema.field(index);
28-
let array = batch.column(index).as_ref();
29-
let geo_arr = NativeArrayDyn::from_arrow_array(array, field)?.into_inner();
30-
native_array_to_pyobject(py, geo_arr)
31+
// Otherwise, assume it's a RecordBatch or RecordBatchStream
32+
let input = input.extract::<AnyRecordBatch>()?;
33+
34+
let schema = input.schema()?;
35+
36+
let (geom_index, geom_type) = if let Some(name) = name {
37+
let (idx, field) = schema
38+
.column_with_name(&name)
39+
.ok_or(PyIndexError::new_err(format!(
40+
"Column name {name} not found"
41+
)))?;
42+
43+
let geom_type = GeoArrowType::from_arrow_field(field)?;
44+
(idx, geom_type)
45+
} else {
46+
let geom_cols = geometry_columns(schema.as_ref());
47+
if geom_cols.is_empty() {
48+
return Err(PyValueError::new_err("No geometry columns found").into());
49+
} else if geom_cols.len() == 1 {
50+
geom_cols.into_iter().next().unwrap()
51+
} else {
52+
return Err(PyValueError::new_err(
53+
"Multiple geometry columns: 'name' must be provided.",
54+
)
55+
.into());
56+
}
57+
};
58+
59+
match input {
60+
AnyRecordBatch::RecordBatch(batch) => {
61+
let geo_array = geom_type.wrap_array(batch.as_ref().column(geom_index).as_ref())?;
62+
Ok(PyGeoArray::new(geo_array).into_bound_py_any(py)?)
3163
}
3264
AnyRecordBatch::Stream(stream) => {
33-
let table = stream.into_table()?;
34-
let table = pytable_to_table(table)?;
35-
let chunked_geom_arr = table.geometry_column(None)?;
36-
chunked_native_array_to_pyobject(py, chunked_geom_arr)
65+
let reader = stream.into_reader()?;
66+
let output_geo_type = geom_type.clone();
67+
let iter = reader
68+
.into_iter()
69+
.map(move |batch| geom_type.wrap_array(batch?.column(geom_index).as_ref()));
70+
let output_reader = Box::new(GeoArrowArrayIterator::new(iter, output_geo_type));
71+
72+
Ok(PyGeoArrayReader::new(output_reader).into_bound_py_any(py)?)
3773
}
3874
}
3975
}
76+
77+
fn geometry_columns(schema: &Schema) -> Vec<(usize, GeoArrowType)> {
78+
schema
79+
.fields()
80+
.iter()
81+
.enumerate()
82+
.filter_map(|(idx, field)| {
83+
if let Ok(geom_type) = GeoArrowType::from_extension_field(field) {
84+
Some((idx, geom_type))
85+
} else {
86+
None
87+
}
88+
})
89+
.collect()
90+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import geopandas as gpd
2+
import numpy as np
3+
import pytest
4+
import shapely
5+
from arro3.core import Table
6+
from geoarrow.rust.core import GeoArray, geometry_col
7+
8+
9+
def geoarrow_array():
10+
geoms = shapely.points([1, 2, 3], [4, 5, 6])
11+
return GeoArray.from_arrow(gpd.GeoSeries(geoms).to_arrow("geoarrow"))
12+
13+
14+
def test_batch_no_geom_cols():
15+
arr = np.array([1, 2, 3])
16+
# We should have a simpler RecordBatch constructor
17+
# https://github.com/kylebarron/arro3/issues/418
18+
batch = Table.from_arrays([arr], names=["no_geom"]).to_batches()[0]
19+
with pytest.raises(ValueError, match="No geometry columns found"):
20+
geometry_col(batch)
21+
22+
23+
def test_batch_one_geom_col():
24+
arr = geoarrow_array()
25+
batch = Table.from_arrays([arr], names=["geom"]).to_batches()[0]
26+
fetched_arr = geometry_col(batch)
27+
assert arr == fetched_arr
28+
29+
30+
def test_batch_two_geom_cols():
31+
arr = geoarrow_array()
32+
batch = Table.from_arrays([arr, arr], names=["geom1", "geom2"]).to_batches()[0]
33+
with pytest.raises(ValueError, match="Multiple geometry columns"):
34+
geometry_col(batch)
35+
36+
assert geometry_col(batch, name="geom1") == arr
37+
assert geometry_col(batch, name="geom2") == arr
38+
39+
40+
def test_geo_array_input():
41+
arr = geoarrow_array()
42+
assert arr == geometry_col(arr)
43+
44+
45+
# TODO: implement once we have an easy GeoChunkedArray constructor
46+
# def test_geo_chunked_array_input():
47+
# arr = geoarrow_array()
48+
# chunked = GeoChunkedArray.from_arrays([arr, arr])
49+
# assert chunked == geometry_col(chunked)
50+
51+
52+
def test_table_no_geom_cols():
53+
arr = np.array([1, 2, 3])
54+
table = Table.from_arrays([arr], names=["no_geom"])
55+
with pytest.raises(ValueError, match="No geometry columns found"):
56+
geometry_col(table)
57+
58+
59+
def test_table_one_geom_col():
60+
arr = geoarrow_array()
61+
table = Table.from_arrays([arr], names=["geom"])
62+
assert geometry_col(table).read_next_array() == arr
63+
64+
65+
def test_table_two_geom_cols():
66+
arr = geoarrow_array()
67+
table = Table.from_arrays([arr, arr], names=["geom1", "geom2"])
68+
with pytest.raises(ValueError, match="Multiple geometry columns"):
69+
geometry_col(table)
70+
71+
assert geometry_col(table, name="geom1").read_next_array() == arr
72+
assert geometry_col(table, name="geom2").read_next_array() == arr

rust/pyo3-geoarrow/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## [Unreleased]
44

5+
- Add `PyGeoArrowArrayReader::into_geoarrow_py`.
6+
57
## 0.6.1 - 2025-10-16
68

79
- docs(pyo3-geoarrow): Improve docs #1377

rust/pyo3-geoarrow/src/array.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,31 @@ impl PyGeoArray {
5151
self.0
5252
}
5353

54-
/// Export to a geoarrow.rust.core.GeoArrowArray.
54+
/// Export to a geoarrow.rust.core.GeoArray.
5555
///
5656
/// This requires that you depend on geoarrow-rust-core from your Python package.
5757
pub fn to_geoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
5858
let geoarrow_mod = py.import(intern!(py, "geoarrow.rust.core"))?;
59+
geoarrow_mod.getattr(intern!(py, "GeoArray"))?.call_method1(
60+
intern!(py, "from_arrow_pycapsule"),
61+
self.__arrow_c_array__(py, None)?,
62+
)
63+
}
64+
65+
/// Export to a geoarrow.rust.core.GeoArray, consuming `self`.
66+
///
67+
/// This requires that you depend on geoarrow-rust-core from your Python package.
68+
pub fn into_geoarrow_py(self, py: Python) -> PyResult<Bound<PyAny>> {
69+
let geoarrow_mod = py.import(intern!(py, "geoarrow.rust.core"))?;
70+
let array_capsules = to_array_pycapsules(
71+
py,
72+
self.0.data_type().to_field("", true).into(),
73+
&self.0.to_array_ref(),
74+
None,
75+
)?;
5976
geoarrow_mod
60-
.getattr(intern!(py, "GeoArrowArray"))?
61-
.call_method1(
62-
intern!(py, "from_arrow_pycapsule"),
63-
self.__arrow_c_array__(py, None)?,
64-
)
77+
.getattr(intern!(py, "GeoArray"))?
78+
.call_method1(intern!(py, "from_arrow_pycapsule"), array_capsules)
6579
}
6680
}
6781

0 commit comments

Comments
 (0)