Skip to content

Commit a21a90c

Browse files
Add ClickHouseBackend(client_fn=...) option to keep chains separate
Closes #14
1 parent 291c489 commit a21a90c

File tree

2 files changed

+98
-10
lines changed

2 files changed

+98
-10
lines changed

mcbackend/backends/clickhouse.py

Lines changed: 37 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
from datetime import datetime, timezone
8-
from typing import Dict, Optional, Sequence, Tuple
8+
from typing import Callable, Dict, Optional, Sequence, Tuple
99

1010
import clickhouse_driver
1111
import numpy
@@ -215,9 +215,14 @@ class ClickHouseRun(Run):
215215
"""Represents an MCMC run stored in ClickHouse."""
216216

217217
def __init__(
218-
self, meta: RunMeta, *, created_at: datetime = None, client: clickhouse_driver.Client
218+
self,
219+
meta: RunMeta,
220+
*,
221+
created_at: datetime = None,
222+
client_fn: Callable[[], clickhouse_driver.Client],
219223
) -> None:
220-
self._client = client
224+
self._client_fn = client_fn
225+
self._client = client_fn()
221226
if created_at is None:
222227
created_at = datetime.now().astimezone(timezone.utc)
223228
self.created_at = created_at
@@ -229,7 +234,7 @@ def __init__(
229234
def init_chain(self, chain_number: int) -> ClickHouseChain:
230235
cmeta = ChainMeta(self.meta.rid, chain_number)
231236
create_chain_table(self._client, cmeta, self.meta)
232-
chain = ClickHouseChain(cmeta, self.meta, client=self._client)
237+
chain = ClickHouseChain(cmeta, self.meta, client=self._client_fn())
233238
if self._chains is None:
234239
self._chains = []
235240
self._chains.append(chain)
@@ -245,16 +250,39 @@ def get_chains(self) -> Tuple[ClickHouseChain]:
245250
chains = []
246251
for (cid,) in self._client.execute(f"SHOW TABLES LIKE '{self.meta.rid}%'"):
247252
cm = ChainMeta(self.meta.rid, int(cid.split("_")[-1]))
248-
chains.append(ClickHouseChain(cm, self.meta, client=self._client))
253+
chains.append(ClickHouseChain(cm, self.meta, client=self._client_fn()))
249254
return tuple(chains)
250255

251256

252257
class ClickHouseBackend(Backend):
253258
"""A backend to store samples in a ClickHouse database."""
254259

255-
def __init__(self, client: clickhouse_driver.Client) -> None:
260+
def __init__(
261+
self,
262+
client: clickhouse_driver.Client = None,
263+
client_fn: Callable[[], clickhouse_driver.Client] = None,
264+
):
265+
"""Create a ClickHouse backend around a database client.
266+
267+
Parameters
268+
----------
269+
client : clickhouse_driver.Client
270+
One client to use for all runs and chains.
271+
client_fn : callable
272+
A function to create database clients.
273+
Use this in multithreading scenarios to get higher insert performance.
274+
"""
275+
if client is None and client_fn is None:
276+
raise ValueError("Either a `client` or a `client_fn` must be provided.")
277+
self._client_fn = client_fn
256278
self._client = client
257-
create_runs_table(client)
279+
280+
if client_fn is None:
281+
self._client_fn = lambda: client
282+
if client is None:
283+
self._client = self._client_fn()
284+
285+
create_runs_table(self._client)
258286
super().__init__()
259287

260288
def init_run(self, meta: RunMeta) -> ClickHouseRun:
@@ -271,7 +299,7 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun:
271299
proto=base64.encodebytes(bytes(meta)).decode("ascii"),
272300
)
273301
self._client.execute(query, [params])
274-
return ClickHouseRun(meta, client=self._client, created_at=created_at)
302+
return ClickHouseRun(meta, client_fn=self._client_fn, created_at=created_at)
275303

276304
def get_runs(self) -> pandas.DataFrame:
277305
df = self._client.query_dataframe(
@@ -295,5 +323,5 @@ def get_run(self, rid: str) -> ClickHouseRun:
295323
data = base64.decodebytes(rows[0][2].encode("ascii"))
296324
meta = RunMeta().parse(data)
297325
return ClickHouseRun(
298-
meta, client=self._client, created_at=rows[0][1].replace(tzinfo=timezone.utc)
326+
meta, client_fn=self._client_fn, created_at=rows[0][1].replace(tzinfo=timezone.utc)
299327
)

mcbackend/test_backend_clickhouse.py

Lines changed: 61 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from subprocess import call
23
from typing import Sequence, Tuple
34

45
import clickhouse_driver
@@ -37,6 +38,63 @@ def fully_initialized(
3738
return run, chains
3839

3940

41+
@pytest.mark.skipif(
42+
condition=not HAS_REAL_DB,
43+
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
44+
)
45+
class TestClickHouseBackendInitialization:
46+
"""This is separate because ``TestClickHouseBackend.setup_method`` depends on these things."""
47+
48+
def test_exceptions(self):
49+
with pytest.raises(ValueError, match="must be provided"):
50+
ClickHouseBackend()
51+
pass
52+
53+
def test_backend_from_client_object(self):
54+
db = "testing_" + hagelkorn.random()
55+
_client_main = clickhouse_driver.Client("localhost")
56+
_client_main.execute(f"CREATE DATABASE {db};")
57+
58+
try:
59+
# When created from a client object, all chains share the client
60+
backend = ClickHouseBackend(client=clickhouse_driver.Client("localhost", database=db))
61+
assert callable(backend._client_fn)
62+
run = backend.init_run(make_runmeta())
63+
c1 = run.init_chain(0)
64+
c2 = run.init_chain(1)
65+
assert c1._client is c2._client
66+
finally:
67+
_client_main.execute(f"DROP DATABASE {db};")
68+
_client_main.disconnect()
69+
pass
70+
71+
def test_backend_from_client_function(self):
72+
db = "testing_" + hagelkorn.random()
73+
_client_main = clickhouse_driver.Client("localhost")
74+
_client_main.execute(f"CREATE DATABASE {db};")
75+
76+
def client_fn():
77+
return clickhouse_driver.Client("localhost", database=db)
78+
79+
try:
80+
# When created from a client function, each chain has its own client
81+
backend = ClickHouseBackend(client_fn=client_fn)
82+
assert backend._client is not None
83+
run = backend.init_run(make_runmeta())
84+
c1 = run.init_chain(0)
85+
c2 = run.init_chain(1)
86+
assert c1._client is not c2._client
87+
88+
# By passing both, one may use different settings
89+
bclient = client_fn()
90+
backend = ClickHouseBackend(client=bclient, client_fn=client_fn)
91+
assert backend._client is bclient
92+
finally:
93+
_client_main.execute(f"DROP DATABASE {db};")
94+
_client_main.disconnect()
95+
pass
96+
97+
4098
@pytest.mark.skipif(
4199
condition=not HAS_REAL_DB,
42100
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
@@ -52,7 +110,9 @@ def setup_method(self, method):
52110
self._client_main = clickhouse_driver.Client("localhost")
53111
self._client_main.execute(f"CREATE DATABASE {self._db};")
54112
self._client = clickhouse_driver.Client("localhost", database=self._db)
55-
self.backend = ClickHouseBackend(self._client)
113+
self.backend = ClickHouseBackend(
114+
client_fn=lambda: clickhouse_driver.Client("localhost", database=self._db)
115+
)
56116
return
57117

58118
def teardown_method(self, method):

0 commit comments

Comments
 (0)