Skip to content

Commit 1d6491f

Browse files
authored
Support geoarrow array input into viz() (#427)
Closes #425, blocked on apache/arrow#38010 (comment). The main issue is that we need a reliable way to maintain the geoarrow extension metadata through FFI. The easiest way would be for `pa.field()` to support `__arrow_c_schema__` input. Alternatively, a context manager of sorts could register global pyarrow geoarrow extension types and then deregister them after use.
1 parent 5bdd908 commit 1d6491f

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

lonboard/_viz.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ class GeoInterfaceProtocol(Protocol):
4646
@property
4747
def __geo_interface__(self) -> dict: ...
4848

49+
class ArrowArrayExportable(Protocol):
50+
def __arrow_c_array__(
51+
self, requested_schema: object | None = None
52+
) -> Tuple[object, object]: ...
53+
4954
class ArrowStreamExportable(Protocol):
5055
def __arrow_c_stream__(
5156
self, requested_schema: object | None = None
@@ -187,6 +192,11 @@ def create_layer_from_data_input(
187192
if isinstance(data, shapely.geometry.base.BaseGeometry):
188193
return _viz_shapely_scalar(data, **kwargs)
189194

195+
# Anything with __arrow_c_array__
196+
if hasattr(data, "__arrow_c_array__"):
197+
data = cast("ArrowArrayExportable", data)
198+
return _viz_geoarrow_array(data, **kwargs)
199+
190200
# Anything with __arrow_c_stream__
191201
if hasattr(data, "__arrow_c_stream__"):
192202
data = cast("ArrowStreamExportable", data)
@@ -296,6 +306,52 @@ def _viz_geo_interface(
296306
raise ValueError(f"type '{geo_interface_type}' not supported.")
297307

298308

309+
def _viz_geoarrow_array(
    data: ArrowArrayExportable,
    **kwargs,
) -> Union[ScatterplotLayer, PathLayer, SolidPolygonLayer]:
    """Build a layer from one geometry array exported via ``__arrow_c_array__``.

    The array is imported into pyarrow together with its field, wrapped in a
    two-column table (``geometry`` plus a ``row_index`` column numbering the
    rows from zero) and handed to ``_viz_geoarrow_table``.

    Args:
        data: Object implementing ``__arrow_c_array__``.
        **kwargs: Forwarded unchanged to ``_viz_geoarrow_table``.

    Returns:
        Whatever ``_viz_geoarrow_table`` returns for the assembled table.

    Raises:
        KeyError: If the installed pyarrow has no
            ``pa.Field._import_from_c_capsule``.
    """
    # Check pyarrow compatibility first, before any capsule is exported from
    # ``data``.
    if not hasattr(pa.Field, "_import_from_c_capsule"):
        raise KeyError(
            "Incompatible version of pyarrow: pa.Field does not have"
            " _import_from_c_capsule method"
        )

    schema_capsule, array_capsule = data.__arrow_c_array__()

    # If the user doesn't have pyarrow extension types registered for geoarrow
    # types, `pa.array()` will lose the extension metadata. Instead, we manually
    # persist the extension metadata by extracting both the field and the array.

    class ArrayHolder:
        """Minimal ``__arrow_c_array__`` provider wrapping two capsules."""

        schema_capsule: object
        array_capsule: object

        def __init__(self, schema_capsule, array_capsule) -> None:
            self.schema_capsule = schema_capsule
            self.array_capsule = array_capsule

        # ``requested_schema`` defaults to None so the signature matches the
        # ArrowArrayExportable protocol; the value itself is ignored.
        def __arrow_c_array__(self, requested_schema=None):
            return self.schema_capsule, self.array_capsule

    field = pa.Field._import_from_c_capsule(schema_capsule)
    # A fresh schema capsule is exported from the imported field rather than
    # reusing the one already handed to `_import_from_c_capsule` above.
    array = pa.array(ArrayHolder(field.__arrow_c_schema__(), array_capsule))
    schema = pa.schema([field.with_name("geometry")])
    table = pa.Table.from_arrays([array], schema=schema)

    # Pick the first unsigned dtype whose maximum is >= the row count, falling
    # back to uint64.
    num_rows = len(array)
    row_index_dtype = next(
        (
            dtype
            for dtype in (np.uint8, np.uint16, np.uint32)
            if num_rows <= np.iinfo(dtype).max
        ),
        np.uint64,
    )
    arange_col = np.arange(num_rows, dtype=row_index_dtype)

    table = table.append_column("row_index", pa.array(arange_col))
    return _viz_geoarrow_table(table, **kwargs)
354+
299355
def _viz_geoarrow_table(
300356
table: pa.Table,
301357
*,

tests/test_viz.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import geodatasets
22
import geopandas as gpd
3+
from geoarrow.rust.core import read_pyogrio
34
from pyogrio.raw import read_arrow
45

56
from lonboard import SolidPolygonLayer, viz
67

78

8-
def test_viz_wkb_geoarrow():
9+
def test_viz_wkb_pyarrow():
910
path = geodatasets.get_path("naturalearth.land")
1011
meta, table = read_arrow(path)
1112
map_ = viz(table)
@@ -37,3 +38,23 @@ def __geo_interface__(self):
3738
map_ = viz(geo_interface_obj)
3839

3940
assert isinstance(map_.layers[0], SolidPolygonLayer)
41+
42+
43+
def test_viz_geoarrow_rust_table():
    """viz() on a table read with geoarrow-rust yields a SolidPolygonLayer first."""
    land_path = geodatasets.get_path("naturalearth.land")
    rust_table = read_pyogrio(land_path)
    rendered = viz(rust_table)
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)
47+
48+
49+
def test_viz_geoarrow_rust_array():
    """viz() on the first geometry chunk of a geoarrow-rust table yields a
    SolidPolygonLayer first.
    """
    land_path = geodatasets.get_path("naturalearth.land")
    rust_table = read_pyogrio(land_path)
    first_chunk = rust_table.geometry.chunk(0)
    rendered = viz(first_chunk)
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)
53+
54+
55+
def test_viz_geoarrow_rust_wkb_array():
    """viz() on the first geometry chunk after ``to_wkb()`` yields a
    SolidPolygonLayer first.
    """
    land_path = geodatasets.get_path("naturalearth.land")
    rust_table = read_pyogrio(land_path)
    first_chunk = rust_table.geometry.chunk(0)
    rendered = viz(first_chunk.to_wkb())
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)

0 commit comments

Comments
 (0)