
Commit 0550bc7

Streaming Arrow reader (#206)
1 parent 61c9fa3 commit 0550bc7

File tree

5 files changed (+154, -13 lines)

CHANGES.md

Lines changed: 5 additions & 0 deletions

@@ -12,9 +12,14 @@
   specifying a mask manually for missing values in `write` (#219)
 - Standardized 3-dimensional geometry type labels from "2.5D <type>" to
   "<type> Z" for consistency with well-known text (WKT) formats (#234)
+- Failure error messages from GDAL are no longer printed to stderr (they were
+  already translated into Python exceptions as well) (#236).
 - Failure and warning error messages from GDAL are no longer printed to
   stderr: failures were already translated into Python exceptions
   and warning messages are now translated into Python warnings (#236, #242).
+- Add access to low-level pyarrow `RecordBatchReader` via
+  `pyogrio.raw.open_arrow`, which allows iterating over batches of Arrow
+  tables (#205).
 
 ### Packaging
 
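For illustration only (not part of the commit): a minimal sketch of the
batch-iteration pattern this changelog entry describes. The file name
"countries.gpkg" and the batch size are placeholders.

    from pyogrio.raw import open_arrow

    with open_arrow("countries.gpkg", batch_size=1_024) as (meta, reader):
        # meta describes the source: crs, fields, encoding, geometry_type,
        # and geometry_name (the Arrow column holding WKB geometry)
        for batch in reader:
            # each item is a pyarrow RecordBatch of at most 1_024 features
            print(batch.num_rows)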

docs/source/api.rst

Lines changed: 6 additions & 0 deletions

@@ -12,3 +12,9 @@ GeoPandas integration
 
 .. autofunction:: pyogrio.read_dataframe
 .. autofunction:: pyogrio.write_dataframe
+
+Reading as Arrow data
+---------------------
+
+.. autofunction:: pyogrio.raw.read_arrow
+.. autofunction:: pyogrio.raw.open_arrow

pyogrio/_io.pyx

Lines changed: 19 additions & 7 deletions

@@ -4,6 +4,7 @@
 """
 
 
+import contextlib
 import datetime
 import locale
 import logging
@@ -1027,8 +1028,8 @@ def ogr_read(
         field_data
     )
 
-
-def ogr_read_arrow(
+@contextlib.contextmanager
+def ogr_open_arrow(
     str path,
     dataset_kwargs,
     object layer=None,
@@ -1043,7 +1044,8 @@ def ogr_read_arrow(
     object fids=None,
     str sql=None,
     str sql_dialect=None,
-    int return_fids=False):
+    int return_fids=False,
+    int batch_size=0):
 
     cdef int err = 0
     cdef const char *path_c = NULL
@@ -1074,6 +1076,7 @@ def ogr_read_arrow(
     if sql is not None and layer is not None:
         raise ValueError("'sql' paramater cannot be combined with 'layer'")
 
+    reader = None
     try:
         dataset_options = dict_to_options(dataset_kwargs)
         ogr_dataset = ogr_open(path_c, 0, dataset_options)
@@ -1129,6 +1132,13 @@ def ogr_read_arrow(
         if not return_fids:
             options = CSLSetNameValue(options, "INCLUDE_FID", "NO")
 
+        if batch_size > 0:
+            options = CSLSetNameValue(
+                options,
+                "MAX_FEATURES_IN_BATCH",
+                str(batch_size).encode('UTF-8')
+            )
+
         # make sure layer is read from beginning
         OGR_L_ResetReading(ogr_layer)
 
@@ -1142,7 +1152,7 @@ def ogr_read_arrow(
 
         # stream has to be consumed before the Dataset is closed
         import pyarrow as pa
-        table = pa.RecordBatchStreamReader._import_from_c(stream_ptr).read_all()
+        reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)
 
         meta = {
             'crs': crs,
@@ -1152,7 +1162,12 @@ def ogr_read_arrow(
             'geometry_name': geometry_name,
         }
 
+        yield meta, reader
+
     finally:
+        if reader is not None:
+            # Mark reader as closed to prevent reading batches
+            reader.close()
 
         CSLDestroy(options)
         if fields_c != NULL:
@@ -1170,9 +1185,6 @@ def ogr_read_arrow(
             GDALClose(ogr_dataset)
             ogr_dataset = NULL
 
-    return meta, table
-
-
 def ogr_read_bounds(
     str path,
     object layer=None,
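The key structural change above turns the reader into a generator-based
context manager: `yield meta, reader` hands control to the caller, and the
`finally` block (which closes the reader and then the GDAL dataset) only runs
once the caller's `with` block exits, so the Arrow stream stays valid while
batches are consumed. A runnable plain-Python sketch of that pattern, with an
ordinary file standing in for the GDAL dataset:

    import contextlib

    @contextlib.contextmanager
    def open_stream(path):
        source = open(path, "rb")  # stand-in for ogr_open()
        try:
            yield source  # caller reads while `source` is still open
        finally:
            # mirrors ogr_open_arrow: close the reader/stream first, then
            # release the underlying dataset
            source.close()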

pyogrio/raw.py

Lines changed: 97 additions & 5 deletions

@@ -6,7 +6,7 @@
 from pyogrio.util import get_vsi_path
 
 with GDALEnv():
-    from pyogrio._io import ogr_read, ogr_read_arrow, ogr_write
+    from pyogrio._io import ogr_open_arrow, ogr_read, ogr_write
     from pyogrio._ogr import (
         get_gdal_version,
         get_gdal_version_string,
@@ -179,7 +179,100 @@ def read_arrow(
     """
     Read OGR data source into a pyarrow Table.
 
-    See docstring of `read` for details.
+    See docstring of `read` for parameters.
+
+    Returns
+    -------
+    (dict, pyarrow.Table)
+
+        Returns a tuple of meta information about the data source in a dict,
+        and a pyarrow Table with data.
+
+        Meta is: {
+            "crs": "<crs>",
+            "fields": <ndarray of field names>,
+            "encoding": "<encoding>",
+            "geometry_type": "<geometry_type>",
+            "geometry_name": "<name of geometry column in arrow table>",
+        }
+    """
+    with open_arrow(
+        path_or_buffer,
+        layer=layer,
+        encoding=encoding,
+        columns=columns,
+        read_geometry=read_geometry,
+        force_2d=force_2d,
+        skip_features=skip_features,
+        max_features=max_features,
+        where=where,
+        bbox=bbox,
+        fids=fids,
+        sql=sql,
+        sql_dialect=sql_dialect,
+        return_fids=return_fids,
+        **kwargs,
+    ) as source:
+        meta, reader = source
+        table = reader.read_all()
+
+    return meta, table
+
+
+def open_arrow(
+    path_or_buffer,
+    /,
+    layer=None,
+    encoding=None,
+    columns=None,
+    read_geometry=True,
+    force_2d=False,
+    skip_features=0,
+    max_features=None,
+    where=None,
+    bbox=None,
+    fids=None,
+    sql=None,
+    sql_dialect=None,
+    return_fids=False,
+    batch_size=65_536,
+    **kwargs,
+):
+    """
+    Open OGR data source as a stream of pyarrow record batches.
+
+    See docstring of `read` for parameters.
+
+    The RecordBatchStreamReader is reading from a stream provided by OGR and must not be
+    accessed after the OGR dataset has been closed, i.e. after the context manager has
+    been closed.
+
+    Examples
+    --------
+
+    >>> from pyogrio.raw import open_arrow
+    >>> import pyarrow as pa
+    >>> import shapely
+    >>>
+    >>> with open_arrow(path) as source:
+    >>>     meta, reader = source
+    >>>     for table in reader:
+    >>>         geometries = shapely.from_wkb(table[meta["geometry_name"]])
+
+    Returns
+    -------
+    (dict, pyarrow.RecordBatchStreamReader)
+
+        Returns a tuple of meta information about the data source in a dict,
+        and a pyarrow RecordBatchStreamReader with data.
+
+        Meta is: {
+            "crs": "<crs>",
+            "fields": <ndarray of field names>,
+            "encoding": "<encoding>",
+            "geometry_type": "<geometry_type>",
+            "geometry_name": "<name of geometry column in arrow table>",
+        }
     """
     try:
         import pyarrow  # noqa
@@ -191,7 +284,7 @@ def read_arrow(
     dataset_kwargs = _preprocess_options_key_value(kwargs) if kwargs else {}
 
     try:
-        result = ogr_read_arrow(
+        return ogr_open_arrow(
             path,
             layer=layer,
             encoding=encoding,
@@ -207,13 +300,12 @@ def read_arrow(
             sql_dialect=sql_dialect,
             return_fids=return_fids,
             dataset_kwargs=dataset_kwargs,
+            batch_size=batch_size,
         )
     finally:
         if buffer is not None:
             remove_virtual_file(path)
 
-    return result
-
 
 def detect_driver(path):
     # try to infer driver from path
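One use this enables (a sketch, not from the commit): copying a large layer
to Parquet one batch at a time, so that only about `batch_size` features are
held in memory at once. `pyarrow.parquet.ParquetWriter` and its `write_batch`
method are standard pyarrow API; the file paths are placeholders.

    import pyarrow.parquet as pq

    from pyogrio.raw import open_arrow

    with open_arrow("input.gpkg", batch_size=10_000) as (meta, reader):
        with pq.ParquetWriter("output.parquet", reader.schema) as writer:
            for batch in reader:
                writer.write_batch(batch)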

pyogrio/tests/test_arrow.py

Lines changed: 27 additions & 1 deletion

@@ -1,7 +1,9 @@
+import math
+
 import pytest
 
 from pyogrio import __gdal_version__, read_dataframe
-from pyogrio.raw import read_arrow
+from pyogrio.raw import open_arrow, read_arrow
 
 try:
     import pandas as pd
@@ -74,3 +76,27 @@ def test_read_arrow_raw(naturalearth_lowres):
     meta, table = read_arrow(naturalearth_lowres)
     assert isinstance(meta, dict)
     assert isinstance(table, pyarrow.Table)
+
+
+def test_open_arrow(naturalearth_lowres):
+    with open_arrow(naturalearth_lowres) as (meta, reader):
+        assert isinstance(meta, dict)
+        assert isinstance(reader, pyarrow.RecordBatchReader)
+        assert isinstance(reader.read_all(), pyarrow.Table)
+
+
+def test_open_arrow_batch_size(naturalearth_lowres):
+    meta, table = read_arrow(naturalearth_lowres)
+    batch_size = math.ceil(len(table) / 2)
+
+    with open_arrow(naturalearth_lowres, batch_size=batch_size) as (meta, reader):
+        assert isinstance(meta, dict)
+        assert isinstance(reader, pyarrow.RecordBatchReader)
+        count = 0
+        tables = []
+        for table in reader:
+            tables.append(table)
+            count += 1
+
+        assert count == 2, "Should be two batches given the batch_size parameter"
+        assert len(tables[0]) == batch_size, "First table should match the batch size"
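The batch-size test relies on how MAX_FEATURES_IN_BATCH partitions the
stream: every batch except possibly the last holds exactly `batch_size`
features. A quick arithmetic check of that expectation (illustrative only;
177 features is an assumption about the naturalearth_lowres fixture):

    import math

    n_features = 177
    batch_size = math.ceil(n_features / 2)        # 89, as computed in the test
    n_batches = math.ceil(n_features / batch_size)
    assert n_batches == 2                         # two batches, as asserted
    assert n_features - batch_size == 88          # the final, smaller batch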
