55import logging
66import time
77from datetime import datetime , timezone
8- from typing import Callable , Dict , Optional , Sequence , Tuple
8+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple
99
1010import clickhouse_driver
1111import numpy
@@ -80,13 +80,13 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
8080 )
8181 columns .append (column_spec_for (var , is_stat = True ))
8282 assert len (set (columns )) == len (columns ), columns
83- columns = ",\n " .join (columns )
83+ cols = ",\n " .join (columns )
8484
8585 query = f"""
8686 CREATE TABLE { cid }
8787 (
8888 `_draw_idx` UInt64,
89- { columns }
89+ { cols }
9090 )
9191 ENGINE TinyLog();
9292 """
@@ -110,8 +110,8 @@ def __init__(
110110 self ._client = client
111111 # The following attributes belong to the batched insert mechanism.
112112 # Inserting in batches is much faster than inserting single rows.
113- self ._insert_query = None
114- self ._insert_queue = []
113+ self ._insert_query : str = ""
114+ self ._insert_queue : List [ Dict [ str , Any ]] = []
115115 self ._last_insert = time .time ()
116116 self ._insert_interval = insert_interval
117117 self ._insert_every = insert_every
@@ -121,7 +121,7 @@ def append(
121121 self , draw : Dict [str , numpy .ndarray ], stats : Optional [Dict [str , numpy .ndarray ]] = None
122122 ):
123123 stat = {f"__stat_{ sname } " : svals for sname , svals in (stats or {}).items ()}
124- params = {"_draw_idx" : self ._draw_idx , ** draw , ** stat }
124+ params : Dict [ str , Any ] = {"_draw_idx" : self ._draw_idx , ** draw , ** stat }
125125 self ._draw_idx += 1
126126 if not self ._insert_query :
127127 names = ", " .join (params .keys ())
@@ -186,9 +186,10 @@ def _get_rows( # pylint: disable=W0221
186186
187187 # The unpacking must also account for non-rigid shapes
188188 if is_rigid (nshape ):
189+ assert nshape is not None
189190 buffer = numpy .empty ((draws , * nshape ), dtype )
190191 else :
191- buffer = numpy .repeat ( None , draws )
192+ buffer = numpy .array ([ None ] * draws )
192193 for d , (vals ,) in enumerate (data ):
193194 buffer [d ] = numpy .asarray (vals , dtype )
194195 return buffer
@@ -228,23 +229,21 @@ def __init__(
228229 self .created_at = created_at
229230 # We need handles on the chains to commit their batched inserts
230231 # before returning them to callers of `.get_chains()`.
231- self ._chains = None
232+ self ._chains : List [ ClickHouseChain ] = []
232233 super ().__init__ (meta )
233234
234235 def init_chain (self , chain_number : int ) -> ClickHouseChain :
235236 cmeta = ChainMeta (self .meta .rid , chain_number )
236237 create_chain_table (self ._client , cmeta , self .meta )
237238 chain = ClickHouseChain (cmeta , self .meta , client = self ._client_fn ())
238- if self ._chains is None :
239- self ._chains = []
240239 self ._chains .append (chain )
241240 return chain
242241
243- def get_chains (self ) -> Tuple [ClickHouseChain ]:
242+ def get_chains (self ) -> Tuple [ClickHouseChain , ... ]:
244243 # Preferably return existing handles on chains that might have
245244 # uncommitted inserts pending.
246245 if self ._chains :
247- return self ._chains
246+ return tuple ( self ._chains )
248247
249248 # Otherwise fetch existing chains from the DB.
250249 chains = []
@@ -274,13 +273,14 @@ def __init__(
274273 """
275274 if client is None and client_fn is None :
276275 raise ValueError ("Either a `client` or a `client_fn` must be provided." )
277- self ._client_fn = client_fn
278- self ._client = client
279276
280277 if client_fn is None :
281- self . _client_fn = lambda : client
278+ client_fn = lambda : client
282279 if client is None :
283- self ._client = self ._client_fn ()
280+ client = client_fn ()
281+
282+ self ._client_fn : Callable [[], clickhouse_driver .Client ] = client_fn
283+ self ._client : clickhouse_driver .Client = client
284284
285285 create_runs_table (self ._client )
286286 super ().__init__ ()
0 commit comments