Skip to content
9 changes: 5 additions & 4 deletions geoarrow-pyarrow/src/geoarrow/pyarrow/_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,12 @@ def infer_type_common(obj, coord_type=None, promote_multi=False, _geometry_types
if promote_multi and geometry_type.value in (1, 2, 3):
geometry_type = GeometryType(geometry_type.value + 3)

spec = TypeSpec.coalesce(
type_spec(Encoding.GEOARROW, dims, geometry_type, coord_type=coord_type),
obj.type.spec,
).canonicalize()
if geometry_type == GeometryType.GEOMETRY:
spec = type_spec(Encoding.WKB)
else:
spec = type_spec(Encoding.GEOARROW, dims, geometry_type, coord_type=coord_type)

spec = TypeSpec.coalesce(spec, obj.type.spec).canonicalize()
return _type.extension_type(spec)


Expand Down
7 changes: 7 additions & 0 deletions geoarrow-pyarrow/src/geoarrow/pyarrow/_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,11 @@ def geometry_type_common(

specs = [t.spec for t in type_objects]
spec = types.TypeSpec.common(*specs).canonicalize()

if (
spec.encoding == types.Encoding.GEOARROW
and spec.geometry_type == types.GeometryType.GEOMETRY
):
spec = types.TypeSpec.coalesce(types.wkb(), spec)

return extension_type(spec)
174 changes: 159 additions & 15 deletions geoarrow-types/src/geoarrow/types/type_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ def __init__(
)

if storage_type is None:
if spec.encoding == Encoding.GEOARROW:
key = spec.geometry_type, spec.coord_type, spec.dimensions
if self._spec.encoding == Encoding.GEOARROW:
key = (
self._spec.geometry_type,
self._spec.coord_type,
self._spec.dimensions,
)
storage_type = _NATIVE_STORAGE_TYPES[key]
else:
storage_type = _SERIALIZED_STORAGE_TYPES[spec.encoding]
storage_type = _SERIALIZED_STORAGE_TYPES[self._spec.encoding]
elif validate_storage_type:
_validate_storage_type(storage_type, spec)
_validate_storage_type(storage_type, self._spec)

pa.ExtensionType.__init__(self, storage_type, self._spec.extension_name())

Expand Down Expand Up @@ -253,6 +257,14 @@ class WktType(GeometryExtensionType):
_extension_name = "geoarrow.wkt"


class GeometryUnionType(GeometryExtensionType):
_extension_name = "geoarrow.geometry"


class GeometryCollectionUnionType(GeometryExtensionType):
_extension_name = "geoarrow.geometrycollection"


class PointType(GeometryExtensionType):
"""Extension type whose storage is an array of points stored
as either a struct with one child per dimension or a fixed-size
Expand Down Expand Up @@ -375,7 +387,7 @@ def from_geobuffers(
def extension_type(
spec: TypeSpec, storage_type=None, validate_storage_type=True
) -> GeometryExtensionType:
spec = spec.with_defaults()
spec = type_spec(spec).with_defaults()
extension_cls = _EXTENSION_CLASSES[spec.extension_name()]
return extension_cls(
spec, storage_type=storage_type, validate_storage_type=validate_storage_type
Expand Down Expand Up @@ -456,12 +468,14 @@ def register_extension_types(lazy: bool = True) -> None:
all_types = [
type_spec(Encoding.WKT).to_pyarrow(),
type_spec(Encoding.WKB).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.GEOMETRY).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.POINT).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.LINESTRING).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.POLYGON).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.MULTIPOINT).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.MULTILINESTRING).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.MULTIPOLYGON).to_pyarrow(),
type_spec(Encoding.GEOARROW, GeometryType.GEOMETRYCOLLECTION).to_pyarrow(),
]

n_registered = 0
Expand Down Expand Up @@ -533,6 +547,13 @@ def _parse_storage(storage_type):
elif isinstance(storage_type, pa.ListType):
f = storage_type.field(0)
return [("list", (f.name,))] + _parse_storage(f.type)
elif isinstance(storage_type, pa.DenseUnionType):
n_fields = storage_type.num_fields
names = tuple(str(code) for code in storage_type.type_codes)
parsed_children = tuple(
_parse_storage(storage_type.field(i).type)[0] for i in range(n_fields)
)
return [("dense_union", (names, parsed_children))]
elif isinstance(storage_type, pa.StructType):
n_fields = storage_type.num_fields
names = tuple(storage_type.field(i).name for i in range(n_fields))
Expand Down Expand Up @@ -577,9 +598,12 @@ def _deserialize_storage(storage_type, extension_name=None, extension_metadata=N
spec = _SPEC_FROM_TYPE_NESTING[parsed_type_names]
spec = TypeSpec.from_extension_metadata(extension_metadata).with_defaults(spec)

# If this is a serialized type, we don't need to infer any more information
# from the storage type.
if spec.encoding.is_serialized():
# If this is a serialized type or a union, we don't need to infer any more information
# from the storage type (because we don't currently validate union types).
if spec.encoding.is_serialized() or spec.geometry_type in (
GeometryType.GEOMETRY,
GeometryType.GEOMETRYCOLLECTION,
):
if extension_name is not None and spec.extension_name() != extension_name:
raise ValueError(f"Can't interpret {storage_type} as {extension_name}")

Expand Down Expand Up @@ -742,6 +766,27 @@ def _from_buffers_multipolygon(
)


ALL_DIMENSIONS = [Dimensions.XY, Dimensions.XYZ, Dimensions.XYM, Dimensions.XYZM]
ALL_COORD_TYPES = [CoordType.INTERLEAVED, CoordType.SEPARATED]
ALL_GEOMETRY_TYPES = [
GeometryType.POINT,
GeometryType.LINESTRING,
GeometryType.POLYGON,
GeometryType.MULTIPOINT,
GeometryType.MULTILINESTRING,
GeometryType.MULTIPOLYGON,
GeometryType.GEOMETRYCOLLECTION,
]
ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION = [
GeometryType.POINT,
GeometryType.LINESTRING,
GeometryType.POLYGON,
GeometryType.MULTIPOINT,
GeometryType.MULTILINESTRING,
GeometryType.MULTIPOLYGON,
]


def _generate_storage_types():
coord_storage = {
(CoordType.SEPARATED, Dimensions.XY): _struct_fields("xy"),
Expand All @@ -763,14 +808,10 @@ def _generate_storage_types():
GeometryType.MULTIPOLYGON: ["polygons", "rings", "vertices"],
}

all_geoemetry_types = list(field_names.keys())
all_coord_types = [CoordType.INTERLEAVED, CoordType.SEPARATED]
all_dimensions = [Dimensions.XY, Dimensions.XYZ, Dimensions.XYM, Dimensions.XYZM]

all_storage_types = {}
for geometry_type in all_geoemetry_types:
for coord_type in all_coord_types:
for dimensions in all_dimensions:
for geometry_type in ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION:
for coord_type in ALL_COORD_TYPES:
for dimensions in ALL_DIMENSIONS:
names = field_names[geometry_type]
coord = coord_storage[(coord_type, dimensions)]
key = geometry_type, coord_type, dimensions
Expand All @@ -780,6 +821,81 @@ def _generate_storage_types():
return all_storage_types


def _generate_union_storage(
geometry_types=ALL_GEOMETRY_TYPES,
dimensions=ALL_DIMENSIONS,
coord_type=CoordType.SEPARATED,
):
child_fields = []
type_codes = []
for dimension in dimensions:
for geometry_type in geometry_types:
spec = type_spec(
encoding=Encoding.GEOARROW,
geometry_type=geometry_type,
dimensions=dimension,
coord_type=coord_type,
)

if spec.geometry_type == GeometryType.GEOMETRYCOLLECTION:
storage_type = _generate_union_collection_storage(
spec.dimensions, coord_type
)
else:
storage_type = extension_type(spec).storage_type

type_id = _UNION_TYPE_ID_FROM_SPEC[(spec.geometry_type, spec.dimensions)]
geometry_type_lab = _UNION_GEOMETRY_TYPE_LABELS[spec.geometry_type.value]
dimension_lab = _UNION_DIMENSION_LABELS[spec.dimensions.value]

child_fields.append(
pa.field(f"{geometry_type_lab}{dimension_lab}", storage_type)
)
type_codes.append(type_id)

return pa.dense_union(child_fields, type_codes)


def _generate_union_collection_storage(dimensions, coord_type):
storage_union = _generate_union_storage(
geometry_types=ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION,
dimensions=[dimensions],
coord_type=coord_type,
)
storage_union_field = pa.field("geometries", storage_union, nullable=False)
return pa.list_(storage_union_field)


def _generate_union_type_id_mapping():
out = {}
for dimension in ALL_DIMENSIONS:
for geometry_type in ALL_GEOMETRY_TYPES:
type_id = (dimension.value - 1) * 10 + geometry_type.value
out[type_id] = (geometry_type, dimension)
return out


def _add_union_types_to_native_storage_types():
global _NATIVE_STORAGE_TYPES

for coord_type in ALL_COORD_TYPES:
for dimension in ALL_DIMENSIONS:
_NATIVE_STORAGE_TYPES[
(GeometryType.GEOMETRY, coord_type, dimension)
] = _generate_union_storage(coord_type=coord_type, dimensions=[dimension])

# With unknown dimensions, we reigster the massive catch-all union
_NATIVE_STORAGE_TYPES[
(GeometryType.GEOMETRY, coord_type, Dimensions.UNKNOWN)
] = _generate_union_storage(coord_type=coord_type)

for coord_type in ALL_COORD_TYPES:
for dimension in ALL_DIMENSIONS:
_NATIVE_STORAGE_TYPES[
(GeometryType.GEOMETRYCOLLECTION, coord_type, dimension)
] = _generate_union_collection_storage(dimension, coord_type)


# A shorter version of repr(spec) that matches what geoarrow-c used to do
# (to reduce mayhem on docstring updates).
def _spec_short_repr(spec, ext_name):
Expand Down Expand Up @@ -819,13 +935,32 @@ def _spec_short_repr(spec, ext_name):
"geoarrow.wkb": WkbType,
"geoarrow.wkt": WktType,
"geoarrow.point": PointType,
"geoarrow.geometry": GeometryUnionType,
"geoarrow.linestring": LinestringType,
"geoarrow.polygon": PolygonType,
"geoarrow.multipoint": MultiPointType,
"geoarrow.multilinestring": MultiLinestringType,
"geoarrow.multipolygon": MultiPolygonType,
"geoarrow.geometrycollection": GeometryCollectionUnionType,
}


_SPEC_FROM_UNION_TYPE_ID = _generate_union_type_id_mapping()
_UNION_TYPE_ID_FROM_SPEC = {v: k for k, v in _SPEC_FROM_UNION_TYPE_ID.items()}

_UNION_GEOMETRY_TYPE_LABELS = [
"Geometry",
"Point",
"LineString",
"Polygon",
"MultiPoint",
"MultiLineString",
"MultiPolygon",
"GeometryCollection",
]

_UNION_DIMENSION_LABELS = [None, "", " Z", " M", " ZM"]

_SERIALIZED_STORAGE_TYPES = {
Encoding.WKT: pa.utf8(),
Encoding.LARGE_WKT: pa.large_utf8(),
Expand All @@ -834,6 +969,7 @@ def _spec_short_repr(spec, ext_name):
}

_NATIVE_STORAGE_TYPES = _generate_storage_types()
_add_union_types_to_native_storage_types()

_SPEC_FROM_TYPE_NESTING = {
("binary",): Encoding.WKB,
Expand All @@ -845,6 +981,10 @@ def _spec_short_repr(spec, ext_name):
geometry_type=GeometryType.POINT,
coord_type=CoordType.SEPARATED,
),
("dense_union",): TypeSpec(
encoding=Encoding.GEOARROW,
geometry_type=GeometryType.GEOMETRY,
),
("list", "struct"): TypeSpec(
encoding=Encoding.GEOARROW, coord_type=CoordType.SEPARATED
),
Expand Down Expand Up @@ -872,6 +1012,10 @@ def _spec_short_repr(spec, ext_name):
coord_type=CoordType.INTERLEAVED,
geometry_type=GeometryType.MULTIPOLYGON,
),
("list", "dense_union"): TypeSpec(
encoding=Encoding.GEOARROW,
geometry_type=GeometryType.GEOMETRYCOLLECTION,
),
}

_DIMS_FROM_NAMES = {
Expand Down
13 changes: 5 additions & 8 deletions geoarrow-types/src/geoarrow/types/type_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,17 @@ def canonicalize(self):

If this type specification represents a serialized type, ensure
that the dimensions are UNKNOWN, the geometry type is GEOMETRY,
and the coord type is UNSPECIFIED. Conversely, when geometry
type is UNKNOWN, the geometry type can't be guessed and we
need to set the encoding to a serialized type.
and the coord type is UNSPECIFIED.

These ensure that when a type
implementation needs to construct a concrete type that its
components are represented consistently.
These ensure that when a type implementation needs to construct a
concrete type that its components are represented consistently.
"""
if self.encoding.is_serialized():
return self.override(
geometry_type=GeometryType.GEOMETRY,
dimensions=Dimensions.UNKNOWN,
coord_type=CoordType.UNSPECIFIED,
)
elif self.geometry_type == GeometryType.GEOMETRY:
return self.override(encoding=Encoding.WKB).canonicalize()
else:
return self

Expand Down Expand Up @@ -604,12 +599,14 @@ def type_spec(
}

_GEOARROW_EXT_NAMES = {
GeometryType.GEOMETRY: "geoarrow.geometry",
GeometryType.POINT: "geoarrow.point",
GeometryType.LINESTRING: "geoarrow.linestring",
GeometryType.POLYGON: "geoarrow.polygon",
GeometryType.MULTIPOINT: "geoarrow.multipoint",
GeometryType.MULTILINESTRING: "geoarrow.multilinestring",
GeometryType.MULTIPOLYGON: "geoarrow.multipolygon",
GeometryType.GEOMETRYCOLLECTION: "geoarrow.geometrycollection",
}

_GEOMETRY_TYPE_FROM_EXT = {v: k for k, v in _GEOARROW_EXT_NAMES.items()}
19 changes: 19 additions & 0 deletions geoarrow-types/tests/test_type_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def test_deserialize_infer_dimensions_interleaved():
)


def test_geometry_union_type():
geometry = gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRY).to_pyarrow()
assert isinstance(geometry, type_pyarrow.GeometryUnionType)
assert geometry.encoding == gt.Encoding.GEOARROW
assert geometry.geometry_type == gt.GeometryType.GEOMETRY


def test_geometry_collection_union_type():
geometry = gt.type_spec(
gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRYCOLLECTION
).to_pyarrow()
assert isinstance(geometry, type_pyarrow.GeometryCollectionUnionType)
assert geometry.encoding == gt.Encoding.GEOARROW
assert geometry.geometry_type == gt.GeometryType.GEOMETRYCOLLECTION


def test_point_array_from_geobuffers():
pa_type = gt.point(dimensions=gt.Dimensions.XYZM).to_pyarrow()
arr = pa_type.from_geobuffers(
Expand Down Expand Up @@ -417,6 +433,9 @@ def test_multipolygon_array_from_geobuffers():
gt.point(dimensions="xyz", coord_type="interleaved"),
gt.point(dimensions="xym", coord_type="interleaved"),
gt.point(dimensions="xyzm", coord_type="interleaved"),
# Union types
gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRY),
gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRYCOLLECTION),
],
)
def test_roundtrip_extension_type(spec):
Expand Down
Loading