5
5
import logging
6
6
import time
7
7
from datetime import datetime , timezone
8
- from typing import Dict , Optional , Sequence , Tuple
8
+ from typing import Callable , Dict , Optional , Sequence , Tuple
9
9
10
10
import clickhouse_driver
11
11
import numpy
@@ -215,9 +215,14 @@ class ClickHouseRun(Run):
215
215
"""Represents an MCMC run stored in ClickHouse."""
216
216
217
217
def __init__ (
218
- self , meta : RunMeta , * , created_at : datetime = None , client : clickhouse_driver .Client
218
+ self ,
219
+ meta : RunMeta ,
220
+ * ,
221
+ created_at : datetime = None ,
222
+ client_fn : Callable [[], clickhouse_driver .Client ],
219
223
) -> None :
220
- self ._client = client
224
+ self ._client_fn = client_fn
225
+ self ._client = client_fn ()
221
226
if created_at is None :
222
227
created_at = datetime .now ().astimezone (timezone .utc )
223
228
self .created_at = created_at
@@ -229,7 +234,7 @@ def __init__(
229
234
def init_chain (self , chain_number : int ) -> ClickHouseChain :
230
235
cmeta = ChainMeta (self .meta .rid , chain_number )
231
236
create_chain_table (self ._client , cmeta , self .meta )
232
- chain = ClickHouseChain (cmeta , self .meta , client = self ._client )
237
+ chain = ClickHouseChain (cmeta , self .meta , client = self ._client_fn () )
233
238
if self ._chains is None :
234
239
self ._chains = []
235
240
self ._chains .append (chain )
@@ -245,16 +250,39 @@ def get_chains(self) -> Tuple[ClickHouseChain]:
245
250
chains = []
246
251
for (cid ,) in self ._client .execute (f"SHOW TABLES LIKE '{ self .meta .rid } %'" ):
247
252
cm = ChainMeta (self .meta .rid , int (cid .split ("_" )[- 1 ]))
248
- chains .append (ClickHouseChain (cm , self .meta , client = self ._client ))
253
+ chains .append (ClickHouseChain (cm , self .meta , client = self ._client_fn () ))
249
254
return tuple (chains )
250
255
251
256
252
257
class ClickHouseBackend (Backend ):
253
258
"""A backend to store samples in a ClickHouse database."""
254
259
255
- def __init__ (self , client : clickhouse_driver .Client ) -> None :
260
+ def __init__ (
261
+ self ,
262
+ client : clickhouse_driver .Client = None ,
263
+ client_fn : Callable [[], clickhouse_driver .Client ] = None ,
264
+ ):
265
+ """Create a ClickHouse backend around a database client.
266
+
267
+ Parameters
268
+ ----------
269
+ client : clickhouse_driver.Client
270
+ One client to use for all runs and chains.
271
+ client_fn : callable
272
+ A function to create database clients.
273
+ Use this in multithreading scenarios to get higher insert performance.
274
+ """
275
+ if client is None and client_fn is None :
276
+ raise ValueError ("Either a `client` or a `client_fn` must be provided." )
277
+ self ._client_fn = client_fn
256
278
self ._client = client
257
- create_runs_table (client )
279
+
280
+ if client_fn is None :
281
+ self ._client_fn = lambda : client
282
+ if client is None :
283
+ self ._client = self ._client_fn ()
284
+
285
+ create_runs_table (self ._client )
258
286
super ().__init__ ()
259
287
260
288
def init_run (self , meta : RunMeta ) -> ClickHouseRun :
@@ -271,7 +299,7 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun:
271
299
proto = base64 .encodebytes (bytes (meta )).decode ("ascii" ),
272
300
)
273
301
self ._client .execute (query , [params ])
274
- return ClickHouseRun (meta , client = self ._client , created_at = created_at )
302
+ return ClickHouseRun (meta , client_fn = self ._client_fn , created_at = created_at )
275
303
276
304
def get_runs (self ) -> pandas .DataFrame :
277
305
df = self ._client .query_dataframe (
@@ -295,5 +323,5 @@ def get_run(self, rid: str) -> ClickHouseRun:
295
323
data = base64 .decodebytes (rows [0 ][2 ].encode ("ascii" ))
296
324
meta = RunMeta ().parse (data )
297
325
return ClickHouseRun (
298
- meta , client = self ._client , created_at = rows [0 ][1 ].replace (tzinfo = timezone .utc )
326
+ meta , client_fn = self ._client_fn , created_at = rows [0 ][1 ].replace (tzinfo = timezone .utc )
299
327
)
0 commit comments