@@ -93,6 +93,46 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
93
93
return client .execute (query )
94
94
95
95
96
+ def where_slice (slc : slice , imax : int , col = "_draw_idx" ) -> Tuple [str , bool ]:
97
+ """Creates a WHERE clause to select rows according to a Python slice.
98
+
99
+ Parameters
100
+ ----------
101
+ slc : slice
102
+ A slice object to apply.
103
+ imax : int
104
+ End of the range to which the slice is applied.
105
+ A `slice(None)` will return this many rows.
106
+ col : str
107
+ Name of the primary key column.
108
+ Assumed to start at 0 with increments of 1.
109
+
110
+ Returns
111
+ -------
112
+ where : str
113
+ WHERE clause for the query.
114
+ reverse : bool
115
+ If True the query result must be reversed because
116
+ the slice had a backwards direction.
117
+ """
118
+ # Determine non-negative slice indices
119
+ istart , istop , istep = slc .indices (imax )
120
+ if istep < 0 :
121
+ istop , istart = istart + 1 , istop + 1
122
+ reverse = True
123
+ else :
124
+ reverse = False
125
+
126
+ # Aggregate conditions
127
+ conds = []
128
+ if istart > 0 :
129
+ conds .append (f"{ col } >={ istart } " )
130
+ conds .append (f"{ col } <{ istop } " )
131
+ if istep != 1 :
132
+ conds .append (f"modulo({ col } - { istart } , { abs (istep )} ) == 0" )
133
+ return "WHERE " + " AND " .join (conds ), reverse
134
+
135
+
96
136
class ClickHouseChain (Chain ):
97
137
"""Represents an MCMC chain stored in ClickHouse."""
98
138
@@ -174,8 +214,11 @@ def _get_rows(
174
214
slc : slice = slice (None ),
175
215
) -> numpy .ndarray :
176
216
self ._commit ()
177
- data = self ._client .execute (f"SELECT (`{ var_name } `) FROM { self .cid } ;" )
217
+ where , reverse = where_slice (slc , self ._draw_idx )
218
+ data = self ._client .execute (f"SELECT (`{ var_name } `) FROM { self .cid } { where } ;" )
178
219
draws = len (data )
220
+ if reverse :
221
+ data = reversed (data )
179
222
180
223
# Without draws return empty arrays of the correct shape/dtype
181
224
if not draws :
@@ -198,9 +241,9 @@ def _get_rows(
198
241
# To circumvent NumPy issue #19113
199
242
arr = numpy .empty (draws , dtype = object )
200
243
arr [:] = buffer
201
- return arr [ slc ]
244
+ return arr
202
245
# Otherwise (identical shapes) we can collapse into one ndarray
203
- return numpy .asarray (buffer , dtype = dtype )[ slc ]
246
+ return numpy .asarray (buffer , dtype = dtype )
204
247
205
248
def get_draws (self , var_name : str , slc : slice = slice (None )) -> numpy .ndarray :
206
249
var = self .variables [var_name ]
0 commit comments