diff --git a/geoarrow-pyarrow/src/geoarrow/pyarrow/_compute.py b/geoarrow-pyarrow/src/geoarrow/pyarrow/_compute.py index c67dfc3..f5e74bb 100644 --- a/geoarrow-pyarrow/src/geoarrow/pyarrow/_compute.py +++ b/geoarrow-pyarrow/src/geoarrow/pyarrow/_compute.py @@ -180,11 +180,12 @@ def infer_type_common(obj, coord_type=None, promote_multi=False, _geometry_types if promote_multi and geometry_type.value in (1, 2, 3): geometry_type = GeometryType(geometry_type.value + 3) - spec = TypeSpec.coalesce( - type_spec(Encoding.GEOARROW, dims, geometry_type, coord_type=coord_type), - obj.type.spec, - ).canonicalize() + if geometry_type == GeometryType.GEOMETRY: + spec = type_spec(Encoding.WKB) + else: + spec = type_spec(Encoding.GEOARROW, dims, geometry_type, coord_type=coord_type) + spec = TypeSpec.coalesce(spec, obj.type.spec).canonicalize() return _type.extension_type(spec) diff --git a/geoarrow-pyarrow/src/geoarrow/pyarrow/_type.py b/geoarrow-pyarrow/src/geoarrow/pyarrow/_type.py index 2448795..fcb1ef0 100644 --- a/geoarrow-pyarrow/src/geoarrow/pyarrow/_type.py +++ b/geoarrow-pyarrow/src/geoarrow/pyarrow/_type.py @@ -161,4 +161,11 @@ def geometry_type_common( specs = [t.spec for t in type_objects] spec = types.TypeSpec.common(*specs).canonicalize() + + if ( + spec.encoding == types.Encoding.GEOARROW + and spec.geometry_type == types.GeometryType.GEOMETRY + ): + spec = types.TypeSpec.coalesce(types.wkb(), spec) + return extension_type(spec) diff --git a/geoarrow-types/src/geoarrow/types/type_pyarrow.py b/geoarrow-types/src/geoarrow/types/type_pyarrow.py index 64d9b0f..9cdcfb6 100644 --- a/geoarrow-types/src/geoarrow/types/type_pyarrow.py +++ b/geoarrow-types/src/geoarrow/types/type_pyarrow.py @@ -37,13 +37,17 @@ def __init__( ) if storage_type is None: - if spec.encoding == Encoding.GEOARROW: - key = spec.geometry_type, spec.coord_type, spec.dimensions + if self._spec.encoding == Encoding.GEOARROW: + key = ( + self._spec.geometry_type, + self._spec.coord_type, + 
self._spec.dimensions, + ) storage_type = _NATIVE_STORAGE_TYPES[key] else: - storage_type = _SERIALIZED_STORAGE_TYPES[spec.encoding] + storage_type = _SERIALIZED_STORAGE_TYPES[self._spec.encoding] elif validate_storage_type: - _validate_storage_type(storage_type, spec) + _validate_storage_type(storage_type, self._spec) pa.ExtensionType.__init__(self, storage_type, self._spec.extension_name()) @@ -253,6 +257,14 @@ class WktType(GeometryExtensionType): _extension_name = "geoarrow.wkt" +class GeometryUnionType(GeometryExtensionType): + _extension_name = "geoarrow.geometry" + + +class GeometryCollectionUnionType(GeometryExtensionType): + _extension_name = "geoarrow.geometrycollection" + + class PointType(GeometryExtensionType): """Extension type whose storage is an array of points stored as either a struct with one child per dimension or a fixed-size @@ -375,7 +387,7 @@ def from_geobuffers( def extension_type( spec: TypeSpec, storage_type=None, validate_storage_type=True ) -> GeometryExtensionType: - spec = spec.with_defaults() + spec = type_spec(spec).with_defaults() extension_cls = _EXTENSION_CLASSES[spec.extension_name()] return extension_cls( spec, storage_type=storage_type, validate_storage_type=validate_storage_type @@ -456,12 +468,14 @@ def register_extension_types(lazy: bool = True) -> None: all_types = [ type_spec(Encoding.WKT).to_pyarrow(), type_spec(Encoding.WKB).to_pyarrow(), + type_spec(Encoding.GEOARROW, GeometryType.GEOMETRY).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.POINT).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.LINESTRING).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.POLYGON).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.MULTIPOINT).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.MULTILINESTRING).to_pyarrow(), type_spec(Encoding.GEOARROW, GeometryType.MULTIPOLYGON).to_pyarrow(), + type_spec(Encoding.GEOARROW, GeometryType.GEOMETRYCOLLECTION).to_pyarrow(), ] n_registered = 0 @@ -533,6 +547,13 @@ 
def _parse_storage(storage_type): elif isinstance(storage_type, pa.ListType): f = storage_type.field(0) return [("list", (f.name,))] + _parse_storage(f.type) + elif isinstance(storage_type, pa.DenseUnionType): + n_fields = storage_type.num_fields + names = tuple(str(code) for code in storage_type.type_codes) + parsed_children = tuple( + _parse_storage(storage_type.field(i).type)[0] for i in range(n_fields) + ) + return [("dense_union", (names, parsed_children))] elif isinstance(storage_type, pa.StructType): n_fields = storage_type.num_fields names = tuple(storage_type.field(i).name for i in range(n_fields)) @@ -577,9 +598,12 @@ def _deserialize_storage(storage_type, extension_name=None, extension_metadata=N spec = _SPEC_FROM_TYPE_NESTING[parsed_type_names] spec = TypeSpec.from_extension_metadata(extension_metadata).with_defaults(spec) - # If this is a serialized type, we don't need to infer any more information - # from the storage type. - if spec.encoding.is_serialized(): + # If this is a serialized type or a union, we don't need to infer any more information + # from the storage type (because we don't currently validate union types). 
+ if spec.encoding.is_serialized() or spec.geometry_type in ( + GeometryType.GEOMETRY, + GeometryType.GEOMETRYCOLLECTION, + ): if extension_name is not None and spec.extension_name() != extension_name: raise ValueError(f"Can't interpret {storage_type} as {extension_name}") @@ -742,6 +766,27 @@ def _from_buffers_multipolygon( ) +ALL_DIMENSIONS = [Dimensions.XY, Dimensions.XYZ, Dimensions.XYM, Dimensions.XYZM] +ALL_COORD_TYPES = [CoordType.INTERLEAVED, CoordType.SEPARATED] +ALL_GEOMETRY_TYPES = [ + GeometryType.POINT, + GeometryType.LINESTRING, + GeometryType.POLYGON, + GeometryType.MULTIPOINT, + GeometryType.MULTILINESTRING, + GeometryType.MULTIPOLYGON, + GeometryType.GEOMETRYCOLLECTION, +] +ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION = [ + GeometryType.POINT, + GeometryType.LINESTRING, + GeometryType.POLYGON, + GeometryType.MULTIPOINT, + GeometryType.MULTILINESTRING, + GeometryType.MULTIPOLYGON, +] + + def _generate_storage_types(): coord_storage = { (CoordType.SEPARATED, Dimensions.XY): _struct_fields("xy"), @@ -763,14 +808,10 @@ def _generate_storage_types(): GeometryType.MULTIPOLYGON: ["polygons", "rings", "vertices"], } - all_geoemetry_types = list(field_names.keys()) - all_coord_types = [CoordType.INTERLEAVED, CoordType.SEPARATED] - all_dimensions = [Dimensions.XY, Dimensions.XYZ, Dimensions.XYM, Dimensions.XYZM] - all_storage_types = {} - for geometry_type in all_geoemetry_types: - for coord_type in all_coord_types: - for dimensions in all_dimensions: + for geometry_type in ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION: + for coord_type in ALL_COORD_TYPES: + for dimensions in ALL_DIMENSIONS: names = field_names[geometry_type] coord = coord_storage[(coord_type, dimensions)] key = geometry_type, coord_type, dimensions @@ -780,6 +821,81 @@ def _generate_storage_types(): return all_storage_types +def _generate_union_storage( + geometry_types=ALL_GEOMETRY_TYPES, + dimensions=ALL_DIMENSIONS, + coord_type=CoordType.SEPARATED, +): + child_fields = [] + type_codes = 
[] + for dimension in dimensions: + for geometry_type in geometry_types: + spec = type_spec( + encoding=Encoding.GEOARROW, + geometry_type=geometry_type, + dimensions=dimension, + coord_type=coord_type, + ) + + if spec.geometry_type == GeometryType.GEOMETRYCOLLECTION: + storage_type = _generate_union_collection_storage( + spec.dimensions, coord_type + ) + else: + storage_type = extension_type(spec).storage_type + + type_id = _UNION_TYPE_ID_FROM_SPEC[(spec.geometry_type, spec.dimensions)] + geometry_type_lab = _UNION_GEOMETRY_TYPE_LABELS[spec.geometry_type.value] + dimension_lab = _UNION_DIMENSION_LABELS[spec.dimensions.value] + + child_fields.append( + pa.field(f"{geometry_type_lab}{dimension_lab}", storage_type) + ) + type_codes.append(type_id) + + return pa.dense_union(child_fields, type_codes) + + +def _generate_union_collection_storage(dimensions, coord_type): + storage_union = _generate_union_storage( + geometry_types=ALL_GEOMETRY_TYPES_EXCEPT_GEOMETRYCOLLECTION, + dimensions=[dimensions], + coord_type=coord_type, + ) + storage_union_field = pa.field("geometries", storage_union, nullable=False) + return pa.list_(storage_union_field) + + +def _generate_union_type_id_mapping(): + out = {} + for dimension in ALL_DIMENSIONS: + for geometry_type in ALL_GEOMETRY_TYPES: + type_id = (dimension.value - 1) * 10 + geometry_type.value + out[type_id] = (geometry_type, dimension) + return out + + +def _add_union_types_to_native_storage_types(): + global _NATIVE_STORAGE_TYPES + + for coord_type in ALL_COORD_TYPES: + for dimension in ALL_DIMENSIONS: + _NATIVE_STORAGE_TYPES[ + (GeometryType.GEOMETRY, coord_type, dimension) + ] = _generate_union_storage(coord_type=coord_type, dimensions=[dimension]) + + # With unknown dimensions, we register the massive catch-all union + _NATIVE_STORAGE_TYPES[ + (GeometryType.GEOMETRY, coord_type, Dimensions.UNKNOWN) + ] = _generate_union_storage(coord_type=coord_type) + + for coord_type in ALL_COORD_TYPES: + for dimension in ALL_DIMENSIONS: + 
_NATIVE_STORAGE_TYPES[ + (GeometryType.GEOMETRYCOLLECTION, coord_type, dimension) + ] = _generate_union_collection_storage(dimension, coord_type) + + # A shorter version of repr(spec) that matches what geoarrow-c used to do # (to reduce mayhem on docstring updates). def _spec_short_repr(spec, ext_name): @@ -819,13 +935,32 @@ def _spec_short_repr(spec, ext_name): "geoarrow.wkb": WkbType, "geoarrow.wkt": WktType, "geoarrow.point": PointType, + "geoarrow.geometry": GeometryUnionType, "geoarrow.linestring": LinestringType, "geoarrow.polygon": PolygonType, "geoarrow.multipoint": MultiPointType, "geoarrow.multilinestring": MultiLinestringType, "geoarrow.multipolygon": MultiPolygonType, + "geoarrow.geometrycollection": GeometryCollectionUnionType, } + +_SPEC_FROM_UNION_TYPE_ID = _generate_union_type_id_mapping() +_UNION_TYPE_ID_FROM_SPEC = {v: k for k, v in _SPEC_FROM_UNION_TYPE_ID.items()} + +_UNION_GEOMETRY_TYPE_LABELS = [ + "Geometry", + "Point", + "LineString", + "Polygon", + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", +] + +_UNION_DIMENSION_LABELS = [None, "", " Z", " M", " ZM"] + _SERIALIZED_STORAGE_TYPES = { Encoding.WKT: pa.utf8(), Encoding.LARGE_WKT: pa.large_utf8(), @@ -834,6 +969,7 @@ def _spec_short_repr(spec, ext_name): } _NATIVE_STORAGE_TYPES = _generate_storage_types() +_add_union_types_to_native_storage_types() _SPEC_FROM_TYPE_NESTING = { ("binary",): Encoding.WKB, @@ -845,6 +981,10 @@ def _spec_short_repr(spec, ext_name): geometry_type=GeometryType.POINT, coord_type=CoordType.SEPARATED, ), + ("dense_union",): TypeSpec( + encoding=Encoding.GEOARROW, + geometry_type=GeometryType.GEOMETRY, + ), ("list", "struct"): TypeSpec( encoding=Encoding.GEOARROW, coord_type=CoordType.SEPARATED ), @@ -872,6 +1012,10 @@ def _spec_short_repr(spec, ext_name): coord_type=CoordType.INTERLEAVED, geometry_type=GeometryType.MULTIPOLYGON, ), + ("list", "dense_union"): TypeSpec( + encoding=Encoding.GEOARROW, + 
geometry_type=GeometryType.GEOMETRYCOLLECTION, + ), } _DIMS_FROM_NAMES = { diff --git a/geoarrow-types/src/geoarrow/types/type_spec.py b/geoarrow-types/src/geoarrow/types/type_spec.py index b094687..ee07afe 100644 --- a/geoarrow-types/src/geoarrow/types/type_spec.py +++ b/geoarrow-types/src/geoarrow/types/type_spec.py @@ -132,13 +132,10 @@ def canonicalize(self): If this type specification represents a serialized type, ensure that the dimensions are UNKNOWN, the geometry type is GEOMETRY, - and the coord type is UNSPECIFIED. Conversely, when geometry - type is UNKNOWN, the geometry type can't be guessed and we - need to set the encoding to a serialized type. + and the coord type is UNSPECIFIED. - These ensure that when a type - implementation needs to construct a concrete type that its - components are represented consistently. + These ensure that when a type implementation needs to construct a + concrete type that its components are represented consistently. """ if self.encoding.is_serialized(): return self.override( @@ -146,8 +143,6 @@ def canonicalize(self): dimensions=Dimensions.UNKNOWN, coord_type=CoordType.UNSPECIFIED, ) - elif self.geometry_type == GeometryType.GEOMETRY: - return self.override(encoding=Encoding.WKB).canonicalize() else: return self @@ -604,12 +599,14 @@ def type_spec( } _GEOARROW_EXT_NAMES = { + GeometryType.GEOMETRY: "geoarrow.geometry", GeometryType.POINT: "geoarrow.point", GeometryType.LINESTRING: "geoarrow.linestring", GeometryType.POLYGON: "geoarrow.polygon", GeometryType.MULTIPOINT: "geoarrow.multipoint", GeometryType.MULTILINESTRING: "geoarrow.multilinestring", GeometryType.MULTIPOLYGON: "geoarrow.multipolygon", + GeometryType.GEOMETRYCOLLECTION: "geoarrow.geometrycollection", } _GEOMETRY_TYPE_FROM_EXT = {v: k for k, v in _GEOARROW_EXT_NAMES.items()} diff --git a/geoarrow-types/tests/test_type_pyarrow.py b/geoarrow-types/tests/test_type_pyarrow.py index 19861b5..8c0da33 100644 --- a/geoarrow-types/tests/test_type_pyarrow.py +++ 
b/geoarrow-types/tests/test_type_pyarrow.py @@ -285,6 +285,22 @@ def test_deserialize_infer_dimensions_interleaved(): ) +def test_geometry_union_type(): + geometry = gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRY).to_pyarrow() + assert isinstance(geometry, type_pyarrow.GeometryUnionType) + assert geometry.encoding == gt.Encoding.GEOARROW + assert geometry.geometry_type == gt.GeometryType.GEOMETRY + + +def test_geometry_collection_union_type(): + geometry = gt.type_spec( + gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRYCOLLECTION + ).to_pyarrow() + assert isinstance(geometry, type_pyarrow.GeometryCollectionUnionType) + assert geometry.encoding == gt.Encoding.GEOARROW + assert geometry.geometry_type == gt.GeometryType.GEOMETRYCOLLECTION + + def test_point_array_from_geobuffers(): pa_type = gt.point(dimensions=gt.Dimensions.XYZM).to_pyarrow() arr = pa_type.from_geobuffers( @@ -417,6 +433,9 @@ def test_multipolygon_array_from_geobuffers(): gt.point(dimensions="xyz", coord_type="interleaved"), gt.point(dimensions="xym", coord_type="interleaved"), gt.point(dimensions="xyzm", coord_type="interleaved"), + # Union types + gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRY), + gt.type_spec(gt.Encoding.GEOARROW, gt.GeometryType.GEOMETRYCOLLECTION), ], ) def test_roundtrip_extension_type(spec):