Skip to content

Commit 291c489

Browse files
Merge pull request #20 from michaelosthege/speedtests
Add speedtests Closes #17
2 parents 7c0378a + 5bd8a6d commit 291c489

File tree

4 files changed

+133
-13
lines changed

4 files changed

+133
-13
lines changed

mcbackend/backends/clickhouse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
*,
104104
client: clickhouse_driver.Client,
105105
insert_interval: int = 1,
106+
insert_every: int = 500,
106107
draw_idx: int = 0,
107108
):
108109
self._draw_idx = draw_idx
@@ -113,6 +114,7 @@ def __init__(
113114
self._insert_queue = []
114115
self._last_insert = time.time()
115116
self._insert_interval = insert_interval
117+
self._insert_every = insert_every
116118
super().__init__(cmeta, rmeta)
117119

118120
def append(
@@ -126,7 +128,10 @@ def append(
126128
self._insert_query = f"INSERT INTO {self.cid} ({names}) VALUES"
127129
self._insert_queue.append(params)
128130

129-
if time.time() - self._last_insert > self._insert_interval:
131+
if (
132+
len(self._insert_queue) >= self._insert_every
133+
or time.time() - self._last_insert > self._insert_interval
134+
):
130135
self._commit()
131136
return
132137

mcbackend/test_backend_clickhouse.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,16 @@
77
import pandas
88
import pytest
99

10-
from mcbackend.meta import ChainMeta, RunMeta, Variable
11-
12-
from .backends.clickhouse import (
10+
from mcbackend.backends.clickhouse import (
1311
ClickHouseBackend,
1412
ClickHouseChain,
1513
ClickHouseRun,
1614
create_chain_table,
1715
create_runs_table,
1816
)
19-
from .core import Chain, Run, chain_id
20-
from .test_utils import CheckBehavior, make_runmeta
17+
from mcbackend.core import Chain, Run, chain_id
18+
from mcbackend.meta import ChainMeta, RunMeta, Variable
19+
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta
2120

2221
try:
2322
client = clickhouse_driver.Client("localhost")
@@ -42,7 +41,7 @@ def fully_initialized(
4241
condition=not HAS_REAL_DB,
4342
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
4443
)
45-
class TestClickHouseBackend(CheckBehavior):
44+
class TestClickHouseBackend(CheckBehavior, CheckPerformance):
4645
cls_backend = ClickHouseBackend
4746
cls_run = ClickHouseRun
4847
cls_chain = ClickHouseChain
@@ -155,3 +154,9 @@ def test_insert_draw(self):
155154
numpy.testing.assert_array_equal(v2, draw["v2"])
156155
numpy.testing.assert_array_equal(v3, draw["v3"])
157156
pass
157+
158+
159+
if __name__ == "__main__":
    # Script entry point: run every benchmark of this test class outside of
    # pytest and print the resulting summary DataFrame.
    benchmark_summary = TestClickHouseBackend().run_all_benchmarks()
    print(benchmark_summary)

mcbackend/test_backend_numpy.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import hagelkorn
44
import numpy
55

6+
from mcbackend.backends.numpy import NumPyBackend, NumPyChain, NumPyRun
7+
from mcbackend.core import RunMeta
68
from mcbackend.meta import Variable
9+
from mcbackend.test_utils import CheckBehavior, CheckPerformance
710

8-
from .backends.numpy import NumPyBackend, NumPyChain, NumPyRun
9-
from .core import RunMeta
10-
from .test_utils import CheckBehavior
1111

12-
13-
class TestNumPyBackend(CheckBehavior):
12+
class TestNumPyBackend(CheckBehavior, CheckPerformance):
1413
cls_backend = NumPyBackend
1514
cls_run = NumPyRun
1615
cls_chain = NumPyChain
@@ -77,3 +76,9 @@ def test_growing(self):
7776
assert chain.get_draws("A").shape == (22, 2)
7877
assert chain.get_draws("B").shape == (22,)
7978
pass
79+
80+
81+
if __name__ == "__main__":
    # Script entry point: run every benchmark of this test class outside of
    # pytest and print the resulting summary DataFrame.
    benchmark_summary = TestNumPyBackend().run_all_benchmarks()
    print(benchmark_summary)

mcbackend/test_utils.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import random
2+
import time
3+
from dataclasses import dataclass
24
from typing import Sequence
35

46
import arviz
57
import hagelkorn
68
import numpy
9+
import pandas
710
import pytest
811

12+
import mcbackend
913
from mcbackend.meta import ChainMeta, DataVariable, RunMeta, Variable
1014
from mcbackend.npproto import utils
1115

@@ -238,14 +242,115 @@ def test__to_inferencedata(self):
238242
pass
239243

240244

241-
class CheckPerformance(BaseBackendTest):
245+
@dataclass
class AppendSpeed:
    """Result of an append benchmark: how fast draws of a given size were stored."""

    # Number of draws appended per second.
    draws_per_second: float
    # Size of a single draw in bytes.
    bytes_per_draw: float

    @property
    def mib_per_second(self) -> float:
        """Data throughput in MiB/s, derived from the draw rate and the draw size."""
        bytes_per_second = self.draws_per_second * self.bytes_per_draw
        return bytes_per_second / 1024 / 1024

    def __str__(self):
        """Format the speed as throughput followed by the draw rate."""
        mib = self.mib_per_second
        dps = self.draws_per_second
        return f"{mib:.1f} MiB/s ({dps:.1f} draws/s)"
256+
257+
258+
def run_chain(run: Run, chain_number: int = 0, tmax: float = 10) -> AppendSpeed:
    """Append with max speed to one chain for `tmax` seconds.

    Parameters
    ----------
    run
        Run on which the chain is created via ``run.init_chain``.
        Its ``meta.variables`` determine the draw that is appended.
    chain_number
        Chain number passed to ``run.init_chain``.
    tmax
        Duration in seconds for which the same draw is appended repeatedly.

    Returns
    -------
    AppendSpeed
        The achieved draws per second together with the byte size of one draw.

    Raises
    ------
    AssertionError
        If the chain length after the benchmark differs from the number
        of appended draws.
    """
    draw = make_draw(run.meta.variables)
    bytes_per_draw = sum(v.size * v.itemsize for v in draw.values())

    chain = run.init_chain(chain_number)
    # Use a monotonic, high-resolution clock: ``time.time()`` can jump when the
    # system clock is adjusted and may tick too coarsely for short measurements.
    clock = time.perf_counter
    t_start = clock()
    d = 0
    last_update = t_start
    while clock() - t_start < tmax:
        chain.append(draw)
        d += 1
        now = clock()
        # Report progress at most once per second.
        if now - last_update > 1:
            print(f"Inserted {d} draws")
            last_update = now

    # NOTE(review): the length check stays inside the timed window, as before;
    # presumably so backends that buffer inserts are charged for it - confirm.
    assert len(chain) == d
    elapsed = clock() - t_start
    # Guard against a zero-length interval instead of raising ZeroDivisionError.
    dps = d / elapsed if elapsed > 0 else 0.0
    return AppendSpeed(dps, bytes_per_draw)
279+
280+
281+
class BackendBenchmark:
    """A collection of backend benchmarking methods.

    Every method whose name starts with ``measure_`` is a benchmark returning an
    ``AppendSpeed``.  The benchmarks read ``self.backend``, which is only declared
    here and must be assigned before they run (presumably by ``setup_method`` of
    the class this is mixed into - confirm).
    """

    backend: mcbackend.Backend

    def run_all_benchmarks(self) -> pandas.DataFrame:
        """Runs each benchmark method and summarizes the results in a DataFrame.

        Returns
        -------
        pandas.DataFrame
            One row per benchmark, indexed by "title" (the method name without
            its ``measure_`` prefix) with the columns "bytes_per_draw",
            "append_speed" (formatted speed) and "description" (method docstring).
        """
        prefix = "measure_"
        value_columns = ["bytes_per_draw", "append_speed", "description"]
        df = pandas.DataFrame(columns=["title", *value_columns]).set_index("title")
        # Optional hook: only call it when the instance actually provides one.
        setup = getattr(self, "setup_method", None)
        # Inspect the class of the instance rather than ``BackendBenchmark``
        # itself, so that ``measure_*`` methods added by subclasses run as well.
        for attr in dir(type(self)):
            # Filter by attribute name first to avoid touching unrelated attributes.
            if not attr.startswith(prefix):
                continue
            meth = getattr(self, attr, None)
            if not callable(meth):
                continue
            if callable(setup):
                try:
                    setup(meth)
                except TypeError as ex:
                    # Best effort, as before, but no longer silent.
                    print(f"Skipped setup_method before {attr}: {ex}")
            print(f"Running {attr}")
            speed = meth()
            df.loc[attr[len(prefix) :], value_columns] = (
                speed.bytes_per_draw,
                str(speed),
                meth.__doc__,
            )
        return df

    # NOTE: the docstrings of the ``measure_*`` methods end up verbatim in the
    # "description" column, therefore they are kept as short one-liners.

    def measure_many_draws(self) -> AppendSpeed:
        """One chain of (), (3,) and (5,2) float32 variables."""
        rmeta = RunMeta(
            rid=hagelkorn.random(),
            variables=[
                Variable("v1", "float32", []),
                Variable("v2", "float32", [3]),
                Variable("v3", "float32", [5, 2]),
            ],
        )
        return run_chain(self.backend.init_run(rmeta))

    def measure_many_variables(self) -> AppendSpeed:
        """One chain with 300 variables of shapes () and (5,)."""
        rmeta = RunMeta(
            rid=hagelkorn.random(),
            # ``[5, 2][: v % 2]`` is ``[]`` for even ``v`` and ``[5]`` for odd ``v``.
            variables=[Variable(f"v{v}", "float32", [5, 2][: v % 2]) for v in range(300)],
        )
        return run_chain(self.backend.init_run(rmeta))

    def measure_big_variables(self) -> AppendSpeed:
        """One chain with 3 variables of shapes (100,), (1000,) and (100, 100)."""
        rmeta = RunMeta(
            rid=hagelkorn.random(),
            variables=[
                Variable("v1", "float32", [100]),
                Variable("v2", "float32", [1000]),
                Variable("v3", "float32", [100, 100]),
            ],
        )
        return run_chain(self.backend.init_run(rmeta))
338+
339+
340+
class CheckPerformance(BaseBackendTest, BackendBenchmark):
    """Checks that the backend is reasonably fast via various high-load tests."""

    # Each test passes as soon as EITHER the draw rate OR the data throughput
    # exceeds its threshold.

    def test__many_draws(self):
        result = self.measure_many_draws()
        assert result.draws_per_second > 5000 or result.mib_per_second > 1

    def test__many_variables(self):
        result = self.measure_many_variables()
        assert result.draws_per_second > 500 or result.mib_per_second > 5

    def test__big_variables(self):
        result = self.measure_big_variables()
        assert result.draws_per_second > 500 or result.mib_per_second > 5

0 commit comments

Comments
 (0)