Skip to content

Commit 806d1ae

Browse files
Implement serialization functionality for Config and VFS objects (#2110)
* Implement serialization functionality for VFS objects
* Implement serialization for Config
* Remove PatchedConfig and PatchedCtx

Co-authored-by: Theodore Tsirpanis <[email protected]>
1 parent e9d05cd commit 806d1ae

File tree

5 files changed

+69
-53
lines changed

5 files changed

+69
-53
lines changed

tiledb/ctx.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,20 @@ def save(self, uri: str):
293293
"""
294294
self.save_to_file(uri)
295295

296+
def __reduce__(self):
    """
    Tell pickle how to serialize and later rebuild this Config object.

    :return: a ``(callable, args, state)`` triple made of this object's
        class, an empty argument tuple, and the parameters reported by
        ``self.dict()``. The state element is what ``__setstate__``
        receives when the object is unpickled.
    """
    params = self.dict()
    cls = self.__class__
    return cls, (), params
303+
304+
def __setstate__(self, state):
    """
    Rebuild this Config object from its pickled state.

    :param state: the saved parameters, i.e. the state element returned
        by ``__reduce__``
    """
    # Re-run the initializer with the saved parameters so the unpickled
    # object ends up carrying the same settings as the original.
    saved_params = state
    self.__init__(saved_params)
309+
296310

297311
class ConfigKeys:
298312
"""

tiledb/tests/conftest.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -52,54 +52,6 @@ def pytest_configure(config):
5252
# default must be set here rather than globally
5353
pytest.tiledb_vfs = "file"
5454

55-
vfs_config(config)
56-
57-
58-
def vfs_config(pytestconfig):
59-
vfs_config_override = {}
60-
61-
vfs = pytestconfig.getoption("vfs")
62-
if vfs == "s3":
63-
pytest.tiledb_vfs = "s3"
64-
65-
vfs_config_override.update(
66-
{
67-
"vfs.s3.endpoint_override": "localhost:9999",
68-
"vfs.s3.aws_access_key_id": "minio",
69-
"vfs.s3.aws_secret_access_key": "miniosecretkey",
70-
"vfs.s3.scheme": "https",
71-
"vfs.s3.verify_ssl": False,
72-
"vfs.s3.use_virtual_addressing": False,
73-
}
74-
)
75-
76-
vfs_config_arg = pytestconfig.getoption("vfs-config", None)
77-
if vfs_config_arg:
78-
pass
79-
80-
tiledb._orig_ctx = tiledb.Ctx
81-
82-
def get_config(config):
83-
final_config = {}
84-
if isinstance(config, tiledb.Config):
85-
final_config = config.dict()
86-
elif config:
87-
final_config = config
88-
89-
final_config.update(vfs_config_override)
90-
return final_config
91-
92-
class PatchedCtx(tiledb.Ctx):
93-
def __init__(self, config=None):
94-
super().__init__(get_config(config))
95-
96-
class PatchedConfig(tiledb.Config):
97-
def __init__(self, params=None):
98-
super().__init__(get_config(params))
99-
100-
tiledb.Ctx = PatchedCtx
101-
tiledb.Config = PatchedConfig
102-
10355

10456
@pytest.fixture(scope="function", autouse=True)
10557
def isolate_os_fork(original_os_fork):

tiledb/tests/test_context_and_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import io
12
import os
3+
import pickle
24
import subprocess
35
import sys
46
import xml
@@ -261,3 +263,20 @@ def test_config_repr_html(self):
261263
pytest.fail(
262264
f"Could not parse config._repr_html_(). Saw {config._repr_html_()}"
263265
)
266+
267+
def test_config_pickle(self):
    """Round-trip a Config through pickle and check that nothing is lost."""
    params = {
        "rest.use_refactored_array_open": "false",
        "rest.use_refactored_array_open_and_query_submit": "true",
        "vfs.azure.storage_account_name": "myaccount",
    }
    original = tiledb.Config(params)

    # Serialize into an in-memory stream, rewind, and load it back.
    with io.BytesIO() as stream:
        pickle.dump(original, stream)
        stream.seek(0)
        restored = pickle.load(stream)

    self.assertIsInstance(restored, tiledb.Config)
    self.assertEqual(restored.dict(), original.dict())

tiledb/tests/test_vfs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import io
22
import os
33
import pathlib
4+
import pickle
45
import random
56
import sys
67

@@ -239,6 +240,21 @@ def test_io(self):
239240
txtio = io.TextIOWrapper(f2, encoding="utf-8")
240241
self.assertEqual(txtio.readlines(), lines)
241242

243+
def test_pickle(self):
    """Round-trip a VFS through pickle and check its config options survive."""
    expected = {"vfs.s3.region": "eu-west-1", "vfs.max_parallel_ops": "1"}
    original = tiledb.VFS(tiledb.Config(expected))

    # Serialize into an in-memory stream, rewind, and load it back.
    with io.BytesIO() as stream:
        pickle.dump(original, stream)
        stream.seek(0)
        restored = pickle.load(stream)

    self.assertIsInstance(restored, tiledb.VFS)
    # Each option given to the original VFS must be readable, unchanged,
    # from the restored one.
    for key, value in expected.items():
        self.assertEqual(restored.config()[key], value)
257+
242258
def test_sc42569_vfs_memoryview(self):
243259
# This test is to ensure that giving np.ndarray buffer to readinto works
244260
# when trying to write bytes that cannot be converted to float32 or int32

tiledb/vfs.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class VFS(lt.VFS):
2525
"""
2626

2727
def __init__(self, config: Union[Config, dict] = None, ctx: Optional[Ctx] = None):
28-
ctx = ctx or default_ctx()
28+
self.ctx = ctx or default_ctx()
2929

3030
if config:
3131
from .libtiledb import Config
@@ -39,12 +39,12 @@ def __init__(self, config: Union[Config, dict] = None, ctx: Optional[Ctx] = None
3939
raise ValueError("`config` argument must be of type Config or dict")
4040

4141
# Convert all values to strings
42-
config = {k: str(v) for k, v in config.items()}
42+
self.config_dict = {k: str(v) for k, v in config.items()}
4343

44-
ccfg = tiledb.Config(config)
45-
super().__init__(ctx, ccfg)
44+
ccfg = tiledb.Config(self.config_dict)
45+
super().__init__(self.ctx, ccfg)
4646
else:
47-
super().__init__(ctx)
47+
super().__init__(self.ctx)
4848

4949
def ctx(self) -> Ctx:
5050
"""
@@ -329,6 +329,21 @@ def touch(self, uri: _AnyPath):
329329
isfile = is_file
330330
size = file_size
331331

332+
# pickling support
333+
def __getstate__(self):
    """
    Return the picklable state of this VFS: a 1-tuple holding its
    configuration as a dict.
    """
    try:
        # Only set by __init__ when the VFS was built with an explicit config.
        options = self.config_dict
    except AttributeError:
        # No explicit config was recorded; fall back to whatever
        # configuration the VFS itself reports.
        options = self.config().dict()
    return (options,)
340+
341+
def __setstate__(self, state):
    """
    Re-initialize this VFS from the state produced by ``__getstate__``.

    :param state: a tuple whose first element is the configuration dict
    """
    restored_config = Config(params=state[0])
    # Build a fresh context from the same configuration and run the normal
    # initializer, so the unpickled VFS is set up like a newly created one.
    self.__init__(config=restored_config, ctx=Ctx(restored_config))
346+
332347

333348
class FileIO(io.RawIOBase):
334349
"""TileDB FileIO class that encapsulates files opened by tiledb.VFS. The file

0 commit comments

Comments
 (0)