@@ -93,6 +93,46 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
9393 return client .execute (query )
9494
9595
96+ def where_slice (slc : slice , imax : int , col = "_draw_idx" ) -> Tuple [str , bool ]:
97+ """Creates a WHERE clause to select rows according to a Python slice.
98+
99+ Parameters
100+ ----------
101+ slc : slice
102+ A slice object to apply.
103+ imax : int
104+ End of the range to which the slice is applied.
105+ A `slice(None)` will return this many rows.
106+ col : str
107+ Name of the primary key column.
108+ Assumed to start at 0 with increments of 1.
109+
110+ Returns
111+ -------
112+ where : str
113+ WHERE clause for the query.
114+ reverse : bool
115+ If True the query result must be reversed because
116+ the slice had a backwards direction.
117+ """
118+ # Determine non-negative slice indices
119+ istart , istop , istep = slc .indices (imax )
120+ if istep < 0 :
121+ istop , istart = istart + 1 , istop + 1
122+ reverse = True
123+ else :
124+ reverse = False
125+
126+ # Aggregate conditions
127+ conds = []
128+ if istart > 0 :
129+ conds .append (f"{ col } >={ istart } " )
130+ conds .append (f"{ col } <{ istop } " )
131+ if istep != 1 :
132+ conds .append (f"modulo({ col } - { istart } , { abs (istep )} ) == 0" )
133+ return "WHERE " + " AND " .join (conds ), reverse
134+
135+
96136class ClickHouseChain (Chain ):
97137 """Represents an MCMC chain stored in ClickHouse."""
98138
@@ -174,8 +214,11 @@ def _get_rows(
174214 slc : slice = slice (None ),
175215 ) -> numpy .ndarray :
176216 self ._commit ()
177- data = self ._client .execute (f"SELECT (`{ var_name } `) FROM { self .cid } ;" )
217+ where , reverse = where_slice (slc , self ._draw_idx )
218+ data = self ._client .execute (f"SELECT (`{ var_name } `) FROM { self .cid } { where } ;" )
178219 draws = len (data )
220+ if reverse :
221+ data = reversed (data )
179222
180223 # Without draws return empty arrays of the correct shape/dtype
181224 if not draws :
@@ -198,9 +241,9 @@ def _get_rows(
198241 # To circumvent NumPy issue #19113
199242 arr = numpy .empty (draws , dtype = object )
200243 arr [:] = buffer
201- return arr [ slc ]
244+ return arr
202245 # Otherwise (identical shapes) we can collapse into one ndarray
203- return numpy .asarray (buffer , dtype = dtype )[ slc ]
246+ return numpy .asarray (buffer , dtype = dtype )
204247
205248 def get_draws (self , var_name : str , slc : slice = slice (None )) -> numpy .ndarray :
206249 var = self .variables [var_name ]
0 commit comments