Skip to content

Commit 4bab496

Browse files
authored
ENH: avoid redundant processing in write_dataframe with use_arrow (#518)
1 parent 09f401a commit 4bab496

File tree

1 file changed

+40
-39
lines changed

1 file changed

+40
-39
lines changed

pyogrio/geopandas.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,6 @@ def read_dataframe(
333333
return gp.GeoDataFrame(df, geometry=geometry, crs=meta["crs"])
334334

335335

336-
# TODO: handle index properly
337336
def write_dataframe(
338337
df,
339338
path,
@@ -469,47 +468,9 @@ def write_dataframe(
469468
if len(geometry_columns) > 0:
470469
geometry_column = geometry_columns[0]
471470
geometry = df[geometry_column]
472-
fields = [c for c in df.columns if not c == geometry_column]
473471
else:
474472
geometry_column = None
475473
geometry = None
476-
fields = list(df.columns)
477-
478-
# TODO: may need to fill in pd.NA, etc
479-
field_data = []
480-
field_mask = []
481-
# dict[str, np.array(int)] special case for dt-tz fields
482-
gdal_tz_offsets = {}
483-
for name in fields:
484-
col = df[name]
485-
if isinstance(col.dtype, pd.DatetimeTZDtype):
486-
# Deal with datetimes with timezones by passing down timezone separately
487-
# pass down naive datetime
488-
naive = col.dt.tz_localize(None)
489-
values = naive.values
490-
# compute offset relative to UTC explicitly
491-
tz_offset = naive - col.dt.tz_convert("UTC").dt.tz_localize(None)
492-
# Convert to GDAL timezone offset representation.
493-
# GMT is represented as 100 and offsets are represented by adding /
494-
# subtracting 1 for every 15 minutes different from GMT.
495-
# https://gdal.org/development/rfc/rfc56_millisecond_precision.html#core-changes
496-
# Convert each row offset to a signed multiple of 15m and add to GMT value
497-
gdal_offset_representation = tz_offset // pd.Timedelta("15m") + 100
498-
gdal_tz_offsets[name] = gdal_offset_representation.values
499-
else:
500-
values = col.values
501-
if isinstance(values, pd.api.extensions.ExtensionArray):
502-
from pandas.arrays import BooleanArray, FloatingArray, IntegerArray
503-
504-
if isinstance(values, (IntegerArray, FloatingArray, BooleanArray)):
505-
field_data.append(values._data)
506-
field_mask.append(values._mask)
507-
else:
508-
field_data.append(np.asarray(values))
509-
field_mask.append(np.asarray(values.isna()))
510-
else:
511-
field_data.append(values)
512-
field_mask.append(None)
513474

514475
# Determine geometry_type and/or promote_to_multi
515476
if geometry_column is not None:
@@ -658,6 +619,46 @@ def write_dataframe(
658619
# If there is geometry data, prepare it to be written
659620
if geometry_column is not None:
660621
geometry = to_wkb(geometry.values)
622+
fields = [c for c in df.columns if not c == geometry_column]
623+
else:
624+
fields = list(df.columns)
625+
626+
# Convert data to numpy arrays for writing
627+
# TODO: may need to fill in pd.NA, etc
628+
field_data = []
629+
field_mask = []
630+
# dict[str, np.array(int)] special case for dt-tz fields
631+
gdal_tz_offsets = {}
632+
for name in fields:
633+
col = df[name]
634+
if isinstance(col.dtype, pd.DatetimeTZDtype):
635+
# Deal with datetimes with timezones by passing down timezone separately
636+
# pass down naive datetime
637+
naive = col.dt.tz_localize(None)
638+
values = naive.values
639+
# compute offset relative to UTC explicitly
640+
tz_offset = naive - col.dt.tz_convert("UTC").dt.tz_localize(None)
641+
# Convert to GDAL timezone offset representation.
642+
# GMT is represented as 100 and offsets are represented by adding /
643+
# subtracting 1 for every 15 minutes different from GMT.
644+
# https://gdal.org/development/rfc/rfc56_millisecond_precision.html#core-changes
645+
# Convert each row offset to a signed multiple of 15m and add to GMT value
646+
gdal_offset_representation = tz_offset // pd.Timedelta("15m") + 100
647+
gdal_tz_offsets[name] = gdal_offset_representation.values
648+
else:
649+
values = col.values
650+
if isinstance(values, pd.api.extensions.ExtensionArray):
651+
from pandas.arrays import BooleanArray, FloatingArray, IntegerArray
652+
653+
if isinstance(values, (IntegerArray, FloatingArray, BooleanArray)):
654+
field_data.append(values._data)
655+
field_mask.append(values._mask)
656+
else:
657+
field_data.append(np.asarray(values))
658+
field_mask.append(np.asarray(values.isna()))
659+
else:
660+
field_data.append(values)
661+
field_mask.append(None)
661662

662663
write(
663664
path,

0 commit comments

Comments
 (0)