Skip to content

Commit 89b02d9

Browse files
ENH: support mask for missing values when writing data / support writing pandas nullable dtypes (#232)
1 parent 976a705 commit 89b02d9

File tree

6 files changed

+86
-3
lines changed

6 files changed

+86
-3
lines changed

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- Add "driver" property to `read_info` result (#224)
99
- Add support for dataset open options to `read`, `read_dataframe`, and
1010
`read_info` (#233)
11+
- Add support for pandas' nullable data types in `write_dataframe`, and for
12+
specifying a mask manually for missing values in `write` (#219)
1113
- Standardized 3-dimensional geometry type labels from "2.5D <type>" to
1214
"<type> Z" for consistency with well-known text (WKT) formats (#234)
1315

pyogrio/_io.pyx

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,7 @@ cdef infer_field_types(list dtypes):
14041404

14051405
# TODO: set geometry and field data as memory views?
14061406
def ogr_write(
1407-
str path, str layer, str driver, geometry, field_data, fields,
1407+
str path, str layer, str driver, geometry, fields, field_data, field_mask,
14081408
str crs, str geometry_type, str encoding, object dataset_kwargs,
14091409
object layer_kwargs, bint promote_to_multi=False, bint nan_as_null=True,
14101410
bint append=False
@@ -1442,6 +1442,15 @@ def ogr_write(
14421442
if len(field_data[i]) != num_records:
14431443
raise ValueError("field_data arrays must be same length as geometry array")
14441444

1445+
if field_mask is not None:
1446+
if len(field_data) != len(field_mask):
1447+
raise ValueError("field_data and field_mask must be same length")
1448+
for i in range(0, len(field_mask)):
1449+
if field_mask[i] is not None and len(field_mask[i]) != num_records:
1450+
raise ValueError("field_mask arrays must be same length as geometry array")
1451+
else:
1452+
field_mask = [None] * len(field_data)
1453+
14451454
path_b = path.encode('UTF-8')
14461455
path_c = path_b
14471456

@@ -1658,7 +1667,11 @@ def ogr_write(
16581667
field_value = field_data[field_idx][i]
16591668
field_type = field_types[field_idx][0]
16601669

1661-
if field_type == OFTString:
1670+
mask = field_mask[field_idx]
1671+
if mask is not None and mask[i]:
1672+
OGR_F_SetFieldNull(ogr_feature, field_idx)
1673+
1674+
elif field_type == OFTString:
16621675
# TODO: encode string using approach from _get_internal_encoding which checks layer capabilities
16631676
if (
16641677
field_value is None

pyogrio/geopandas.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,22 @@ def write_dataframe(
311311
fields = [c for c in df.columns if not c == geometry_column]
312312

313313
# TODO: may need to fill in pd.NA, etc
314-
field_data = [df[f].values for f in fields]
314+
field_data = []
315+
field_mask = []
316+
for name in fields:
317+
col = df[name].values
318+
if isinstance(col, pd.api.extensions.ExtensionArray):
319+
from pandas.arrays import IntegerArray, FloatingArray, BooleanArray
320+
321+
if isinstance(col, (IntegerArray, FloatingArray, BooleanArray)):
322+
field_data.append(col._data)
323+
field_mask.append(col._mask)
324+
else:
325+
field_data.append(np.asarray(col))
326+
field_mask.append(np.asarray(col.isna()))
327+
else:
328+
field_data.append(col)
329+
field_mask.append(None)
315330

316331
# Determine geometry_type and/or promote_to_multi
317332
if geometry_type is None or promote_to_multi is None:
@@ -386,6 +401,7 @@ def write_dataframe(
386401
driver=driver,
387402
geometry=to_wkb(geometry.values),
388403
field_data=field_data,
404+
field_mask=field_mask,
389405
fields=fields,
390406
crs=crs,
391407
geometry_type=geometry_type,

pyogrio/raw.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def write(
278278
geometry,
279279
field_data,
280280
fields,
281+
field_mask=None,
281282
layer=None,
282283
driver=None,
283284
# derived from meta if roundtrip
@@ -349,6 +350,7 @@ def write(
349350
geometry=geometry,
350351
geometry_type=geometry_type,
351352
field_data=field_data,
353+
field_mask=field_mask,
352354
fields=fields,
353355
crs=crs,
354356
encoding=encoding,

pyogrio/tests/test_geopandas_io.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,3 +1022,25 @@ def test_read_dataset_kwargs(data_dir, use_arrow):
10221022
def test_read_invalid_dataset_kwargs(capfd, naturalearth_lowres, use_arrow):
10231023
read_dataframe(naturalearth_lowres, use_arrow=use_arrow, INVALID="YES")
10241024
assert "does not support open option INVALID" in capfd.readouterr().err
1025+
1026+
1027+
def test_write_nullable_dtypes(tmp_path):
1028+
path = tmp_path / "test_nullable_dtypes.gpkg"
1029+
test_data = {
1030+
"col1": pd.Series([1, 2, 3], dtype="int64"),
1031+
"col2": pd.Series([1, 2, None], dtype="Int64"),
1032+
"col3": pd.Series([0.1, None, 0.3], dtype="Float32"),
1033+
"col4": pd.Series([True, False, None], dtype="boolean"),
1034+
"col5": pd.Series(["a", None, "b"], dtype="string"),
1035+
}
1036+
input_gdf = gp.GeoDataFrame(test_data, geometry=[Point(0, 0)] * 3, crs="epsg:31370")
1037+
write_dataframe(input_gdf, path)
1038+
output_gdf = read_dataframe(path)
1039+
# We read it back as default (non-nullable) numpy dtypes, so we cast
1040+
# to those for the expected result
1041+
expected = input_gdf.copy()
1042+
expected["col2"] = expected["col2"].astype("float64")
1043+
expected["col3"] = expected["col3"].astype("float32")
1044+
expected["col4"] = expected["col4"].astype("float64")
1045+
expected["col5"] = expected["col5"].astype(object)
1046+
assert_geodataframe_equal(output_gdf, expected)

pyogrio/tests/test_raw_io.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,31 @@ def test_encoding_io_shapefile(tmp_path, read_encoding, write_encoding):
806806
assert np.array_equal(
807807
fields, read_info(filename, encoding=read_encoding)["fields"]
808808
)
809+
810+
811+
def test_write_with_mask(tmp_path):
812+
# WKB for Point(0, 0), repeated for each of the 3 records (no null geometries)
813+
geometry = np.array(
814+
[bytes.fromhex("010100000000000000000000000000000000000000")] * 3,
815+
dtype=object,
816+
)
817+
field_data = [np.array([1, 2, 3], dtype="int32")]
818+
field_mask = [np.array([False, True, False])]
819+
fields = ["col"]
820+
meta = dict(geometry_type="Point", crs="EPSG:4326")
821+
822+
filename = tmp_path / "test.geojson"
823+
write(filename, geometry, field_data, fields, field_mask, **meta)
824+
result_geometry, result_fields = read(filename)[2:]
825+
assert np.array_equal(result_geometry, geometry)
826+
np.testing.assert_allclose(result_fields[0], np.array([1, np.nan, 3]))
827+
828+
# wrong length for mask
829+
field_mask = [np.array([False, True])]
830+
with pytest.raises(ValueError):
831+
write(filename, geometry, field_data, fields, field_mask, **meta)
832+
833+
# wrong number of mask arrays
834+
field_mask = [np.array([False, True, False])] * 2
835+
with pytest.raises(ValueError):
836+
write(filename, geometry, field_data, fields, field_mask, **meta)

0 commit comments

Comments
 (0)