Skip to content

Commit 1016b19

Browse files
Implement __getstate__ and __setstate__ so that FileIO instances can be pickled (#543)
1 parent 4148edb commit 1016b19

File tree

4 files changed

+75
-0
lines changed

4 files changed

+75
-0
lines changed

pyiceberg/io/fsspec.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import logging
2222
import os
23+
from copy import copy
2324
from functools import lru_cache, partial
2425
from typing import (
2526
Any,
@@ -338,3 +339,14 @@ def _get_fs(self, scheme: str) -> AbstractFileSystem:
338339
if scheme not in self._scheme_to_fs:
339340
raise ValueError(f"No registered filesystem for scheme: {scheme}")
340341
return self._scheme_to_fs[scheme](self.properties)
342+
343+
def __getstate__(self) -> Dict[str, Any]:
344+
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
345+
fileio_copy = copy(self.__dict__)
346+
fileio_copy["get_fs"] = None
347+
return fileio_copy
348+
349+
def __setstate__(self, state: Dict[str, Any]) -> None:
350+
"""Deserialize the state into a FsSpecFileIO instance."""
351+
self.__dict__ = state
352+
self.get_fs = lru_cache(self._get_fs)

pyiceberg/io/pyarrow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import re
3434
from abc import ABC, abstractmethod
3535
from concurrent.futures import Future
36+
from copy import copy
3637
from dataclasses import dataclass
3738
from enum import Enum
3839
from functools import lru_cache, singledispatch
@@ -456,6 +457,17 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
456457
raise PermissionError(f"Cannot delete file, access denied: {location}") from e
457458
raise # pragma: no cover - If some other kind of OSError, raise the raw error
458459

460+
def __getstate__(self) -> Dict[str, Any]:
461+
"""Create a dictionary of the PyArrowFileIO fields used when pickling."""
462+
fileio_copy = copy(self.__dict__)
463+
fileio_copy["fs_by_scheme"] = None
464+
return fileio_copy
465+
466+
def __setstate__(self, state: Dict[str, Any]) -> None:
467+
"""Deserialize the state into a PyArrowFileIO instance."""
468+
self.__dict__ = state
469+
self.fs_by_scheme = lru_cache(self._initialize_fs)
470+
459471

460472
def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
461473
return visit(schema, _ConvertToArrowSchema(metadata))

tests/io/test_fsspec.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import os
19+
import pickle
1920
import tempfile
2021
import uuid
2122

@@ -229,6 +230,11 @@ def test_writing_avro_file(generated_manifest_entry_file: str, fsspec_fileio: Fs
229230
fsspec_fileio.delete(f"s3://warehouse/{filename}")
230231

231232

233+
@pytest.mark.s3
234+
def test_fsspec_pickle_round_trip_s3(fsspec_fileio: FsspecFileIO) -> None:
235+
_test_fsspec_pickle_round_trip(fsspec_fileio, "s3://warehouse/foo.txt")
236+
237+
232238
@pytest.mark.adlfs
233239
def test_fsspec_new_input_file_adlfs(adlfs_fsspec_fileio: FsspecFileIO) -> None:
234240
"""Test creating a new input file from an fsspec file-io"""
@@ -410,6 +416,11 @@ def test_writing_avro_file_adlfs(generated_manifest_entry_file: str, adlfs_fsspe
410416
adlfs_fsspec_fileio.delete(f"abfss://tests/{filename}")
411417

412418

419+
@pytest.mark.adlfs
420+
def test_fsspec_pickle_round_trip_aldfs(adlfs_fsspec_fileio: FsspecFileIO) -> None:
421+
_test_fsspec_pickle_round_trip(adlfs_fsspec_fileio, "abfss://tests/foo.txt")
422+
423+
413424
@pytest.mark.gcs
414425
def test_fsspec_new_input_file_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
415426
"""Test creating a new input file from a fsspec file-io"""
@@ -586,6 +597,26 @@ def test_writing_avro_file_gcs(generated_manifest_entry_file: str, fsspec_fileio
586597
fsspec_fileio_gcs.delete(f"gs://warehouse/{filename}")
587598

588599

600+
@pytest.mark.gcs
601+
def test_fsspec_pickle_roundtrip_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
602+
_test_fsspec_pickle_round_trip(fsspec_fileio_gcs, "gs://warehouse/foo.txt")
603+
604+
605+
def _test_fsspec_pickle_round_trip(fsspec_fileio: FsspecFileIO, location: str) -> None:
606+
serialized_file_io = pickle.dumps(fsspec_fileio)
607+
deserialized_file_io = pickle.loads(serialized_file_io)
608+
output_file = deserialized_file_io.new_output(location)
609+
with output_file.create() as f:
610+
f.write(b"foo")
611+
612+
input_file = deserialized_file_io.new_input(location)
613+
with input_file.open() as f:
614+
data = f.read()
615+
assert data == b"foo"
616+
assert len(input_file) == 3
617+
deserialized_file_io.delete(location)
618+
619+
589620
TEST_URI = "https://iceberg-test-signer"
590621

591622

tests/io/test_io.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import os
19+
import pickle
1920
import tempfile
2021

2122
import pytest
@@ -71,6 +72,25 @@ def test_custom_local_output_file() -> None:
7172
assert len(output_file) == 3
7273

7374

75+
def test_pickled_pyarrow_round_trip() -> None:
76+
with tempfile.TemporaryDirectory() as tmpdirname:
77+
file_location = os.path.join(tmpdirname, "foo.txt")
78+
file_io = PyArrowFileIO()
79+
serialized_file_io = pickle.dumps(file_io)
80+
deserialized_file_io = pickle.loads(serialized_file_io)
81+
absolute_file_location = os.path.abspath(file_location)
82+
output_file = deserialized_file_io.new_output(location=f"{absolute_file_location}")
83+
with output_file.create() as f:
84+
f.write(b"foo")
85+
86+
input_file = deserialized_file_io.new_input(location=f"{absolute_file_location}")
87+
f = input_file.open()
88+
data = f.read()
89+
assert data == b"foo"
90+
assert len(input_file) == 3
91+
deserialized_file_io.delete(location=f"{absolute_file_location}")
92+
93+
7494
def test_custom_local_output_file_with_overwrite() -> None:
7595
"""Test initializing an OutputFile implementation to overwrite a local file"""
7696
with tempfile.TemporaryDirectory() as tmpdirname:

0 commit comments

Comments
 (0)