Skip to content

Commit 8faf25e

Browse files
Use smart WHERE clauses in ClickHouseChain._get_rows
Closes #50
1 parent e17e5b4 commit 8faf25e

File tree

1 file changed

+46
-3
lines changed

1 file changed

+46
-3
lines changed

mcbackend/backends/clickhouse.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,46 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
9393
return client.execute(query)
9494

9595

96+
def where_slice(slc: slice, imax: int, col="_draw_idx") -> Tuple[str, bool]:
97+
"""Creates a WHERE clause to select rows according to a Python slice.
98+
99+
Parameters
100+
----------
101+
slc : slice
102+
A slice object to apply.
103+
imax : int
104+
End of the range to which the slice is applied.
105+
A `slice(None)` will return this many rows.
106+
col : str
107+
Name of the primary key column.
108+
Assumed to start at 0 with increments of 1.
109+
110+
Returns
111+
-------
112+
where : str
113+
WHERE clause for the query.
114+
reverse : bool
115+
If True the query result must be reversed because
116+
the slice had a backwards direction.
117+
"""
118+
# Determine non-negative slice indices
119+
istart, istop, istep = slc.indices(imax)
120+
if istep < 0:
121+
istop, istart = istart + 1, istop + 1
122+
reverse = True
123+
else:
124+
reverse = False
125+
126+
# Aggregate conditions
127+
conds = []
128+
if istart > 0:
129+
conds.append(f"{col}>={istart}")
130+
conds.append(f"{col}<{istop}")
131+
if istep != 1:
132+
conds.append(f"modulo({col} - {istart}, {abs(istep)}) == 0")
133+
return "WHERE " + " AND ".join(conds), reverse
134+
135+
96136
class ClickHouseChain(Chain):
97137
"""Represents an MCMC chain stored in ClickHouse."""
98138

@@ -174,8 +214,11 @@ def _get_rows(
174214
slc: slice = slice(None),
175215
) -> numpy.ndarray:
176216
self._commit()
177-
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid};")
217+
where, reverse = where_slice(slc, self._draw_idx)
218+
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid} {where};")
178219
draws = len(data)
220+
if reverse:
221+
data = reversed(data)
179222

180223
# Without draws return empty arrays of the correct shape/dtype
181224
if not draws:
@@ -198,9 +241,9 @@ def _get_rows(
198241
# To circumvent NumPy issue #19113
199242
arr = numpy.empty(draws, dtype=object)
200243
arr[:] = buffer
201-
return arr[slc]
244+
return arr
202245
# Otherwise (identical shapes) we can collapse into one ndarray
203-
return numpy.asarray(buffer, dtype=dtype)[slc]
246+
return numpy.asarray(buffer, dtype=dtype)
204247

205248
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
206249
var = self.variables[var_name]

0 commit comments

Comments
 (0)