Skip to content

Commit a99f3f6

Browse files
Fix reading from zstd decompression stream (#443)
1 parent 6a2b0d8 commit a99f3f6

File tree

6 files changed

+61
-1
lines changed

6 files changed

+61
-1
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
- Silence warning from `write_dataframe` with `GeoSeries.notna()` (#435).
88
- BUG: Enable mask & bbox filter when geometry column not read (#431).
9+
- Prevent seek on read from compressed inputs (#443).
910

1011
## 0.9.0 (2024-06-17)
1112

pyogrio/tests/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from io import BytesIO
12
from pathlib import Path
23
from zipfile import ZIP_DEFLATED, ZipFile
34

@@ -178,6 +179,31 @@ def geojson_filelike(tmp_path):
178179
yield f
179180

180181

182+
@pytest.fixture(scope="function")
def nonseekable_bytes(tmp_path):
    """Return an in-memory GeoJSON stream that cannot be rewound.

    The stream holds a FeatureCollection with a single Point feature at
    (1, 1) and an empty ``properties`` mapping.  It mimics a non-seekable
    byte stream, such as a zstandard handle: ``seekable()`` reports
    ``False`` and any call to ``seek()`` raises ``OSError``.
    """
    # NOTE(review): ``tmp_path`` is requested but never used in this body.

    class NonSeekableBytesIO(BytesIO):
        """BytesIO variant that refuses every seek request."""

        def seek(self, *args, **kwargs):
            raise OSError("cannot seek")

        def seekable(self):
            return False

    # GeoJSON payload that gets wrapped into the non-seekable stream
    geojson = """{
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "properties": { },
                "geometry": { "type": "Point", "coordinates": [1, 1] }
            }
        ]
    }"""

    payload = geojson.encode("UTF-8")
    return NonSeekableBytesIO(payload)
205+
206+
181207
@pytest.fixture(
182208
scope="session",
183209
params=[

pyogrio/tests/test_arrow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ def test_read_arrow_bytes(geojson_bytes):
168168
assert len(table) == 3
169169

170170

171+
def test_read_arrow_nonseekable_bytes(nonseekable_bytes):
    """read_arrow must handle a stream that reports itself as non-seekable."""
    result = read_arrow(nonseekable_bytes)
    metadata, arrow_table = result

    # the fixture's only feature has no properties, hence no fields
    assert metadata["fields"].shape == (0,)
    assert len(arrow_table) == 1
175+
176+
171177
def test_read_arrow_filelike(geojson_filelike):
172178
meta, table = read_arrow(geojson_filelike)
173179

pyogrio/tests/test_core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def test_list_layers_bytes(geojson_bytes):
184184
assert layers[0, 0] == "test"
185185

186186

187+
def test_list_layers_nonseekable_bytes(nonseekable_bytes):
    """list_layers must handle a stream that reports itself as non-seekable."""
    found = list_layers(nonseekable_bytes)

    # one row with two columns; the second column holds "Point"
    assert found.shape == (1, 2)
    assert found[0, 1] == "Point"
192+
193+
187194
def test_list_layers_filelike(geojson_filelike):
188195
layers = list_layers(geojson_filelike)
189196

@@ -218,6 +225,13 @@ def test_read_bounds_bytes(geojson_bytes):
218225
assert allclose(bounds[:, 0], [-180.0, -18.28799, 180.0, -16.02088])
219226

220227

228+
def test_read_bounds_nonseekable_bytes(nonseekable_bytes):
    """read_bounds must handle a stream that reports itself as non-seekable."""
    result = read_bounds(nonseekable_bytes)
    feature_ids, extents = result

    assert feature_ids.shape == (1,)
    assert extents.shape == (4, 1)
    # the fixture's only geometry is the point (1, 1)
    assert allclose(extents[:, 0], [1, 1, 1, 1])
233+
234+
221235
def test_read_bounds_filelike(geojson_filelike):
222236
fids, bounds = read_bounds(geojson_filelike)
223237
assert fids.shape == (3,)
@@ -449,6 +463,13 @@ def test_read_info_bytes(geojson_bytes):
449463
assert meta["features"] == 3
450464

451465

466+
def test_read_info_nonseekable_bytes(nonseekable_bytes):
    """read_info must handle a stream that reports itself as non-seekable."""
    info = read_info(nonseekable_bytes)

    # the fixture's only feature has no properties, hence no fields
    assert info["fields"].shape == (0,)
    assert info["features"] == 1
471+
472+
452473
def test_read_info_filelike(geojson_filelike):
453474
meta = read_info(geojson_filelike)
454475

pyogrio/tests/test_raw_io.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,12 @@ def test_read_from_file_like(tmp_path, naturalearth_lowres, driver, ext):
819819
assert_equal_result((meta, index, geometry, field_data), result2)
820820

821821

822+
def test_read_from_nonseekable_bytes(nonseekable_bytes):
    """read must handle a stream that reports itself as non-seekable."""
    # only the first and third items of the 4-tuple are checked here
    metadata, _second, geoms, _fourth = read(nonseekable_bytes)

    assert metadata["fields"].shape == (0,)
    assert len(geoms) == 1
826+
827+
822828
@pytest.mark.parametrize("ext", ["gpkg", "fgb"])
823829
def test_read_write_data_types_numeric(tmp_path, ext):
824830
# Point(0, 0)

pyogrio/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def get_vsi_path_or_buffer(path_or_buffer):
3838
bytes_buffer = path_or_buffer.read()
3939

4040
# rewind buffer if possible so that subsequent operations do not need to rewind
41-
if hasattr(path_or_buffer, "seek"):
41+
if hasattr(path_or_buffer, "seekable") and path_or_buffer.seekable():
4242
path_or_buffer.seek(0)
4343

4444
return bytes_buffer

0 commit comments

Comments
 (0)