Skip to content

Commit cb63987

Browse files
committed
Add tests for composition of formula receivers
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 0f463a3 commit cb63987

File tree

1 file changed

+266
-3
lines changed

1 file changed

+266
-3
lines changed

tests/timeseries/test_formula_engine.py

Lines changed: 266 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55

66
import asyncio
77
from datetime import datetime
8-
from typing import Dict, List, Optional, Tuple
8+
from typing import Callable, Dict, List, Optional, Tuple, Union
99

10-
from frequenz.channels import Broadcast
10+
from frequenz.channels import Broadcast, Receiver
1111

1212
from frequenz.sdk.timeseries import Sample
13-
from frequenz.sdk.timeseries.logical_meter._formula_engine import FormulaBuilder
13+
from frequenz.sdk.timeseries.logical_meter._formula_engine import (
14+
FormulaBuilder,
15+
FormulaEngine,
16+
FormulaReceiver,
17+
HigherOrderFormulaBuilder,
18+
)
1419
from frequenz.sdk.timeseries.logical_meter._tokenizer import Token, Tokenizer, TokenType
1520

1621

@@ -79,6 +84,7 @@ async def run_test(
7984
next_val = await engine._apply() # pylint: disable=protected-access
8085
assert (next_val).value == io_output
8186
tests_passed += 1
87+
await engine._stop() # pylint: disable=protected-access
8288
assert tests_passed == len(io_pairs)
8389

8490
async def test_simple(self) -> None:
@@ -278,3 +284,260 @@ async def test_nones_are_not_zeros(self) -> None:
278284
],
279285
False,
280286
)
287+
288+
289+
class TestFormulaChannel:
290+
"""Tests for formula channels."""
291+
292+
def make_engine(self, stream_id: int, data: Receiver[Sample]) -> FormulaEngine:
293+
"""Make a basic FormulaEngine."""
294+
name = f"#{stream_id}"
295+
builder = FormulaBuilder(name)
296+
builder.push_metric(
297+
name,
298+
data,
299+
nones_are_zeros=False,
300+
)
301+
return builder.build()
302+
303+
async def run_test( # pylint: disable=too-many-locals
304+
self,
305+
num_items: int,
306+
make_builder: Union[
307+
Callable[
308+
[FormulaReceiver, FormulaReceiver, FormulaReceiver],
309+
HigherOrderFormulaBuilder,
310+
],
311+
Callable[
312+
[FormulaReceiver, FormulaReceiver, FormulaReceiver, FormulaReceiver],
313+
HigherOrderFormulaBuilder,
314+
],
315+
],
316+
io_pairs: List[Tuple[List[Optional[float]], Optional[float]]],
317+
nones_are_zeros: bool = False,
318+
) -> None:
319+
"""Run a test with the specs provided."""
320+
channels = [Broadcast[Sample](str(ctr)) for ctr in range(num_items)]
321+
l1_engines = [
322+
self.make_engine(ctr, channels[ctr].new_receiver())
323+
for ctr in range(num_items)
324+
]
325+
builder = make_builder(*[e.new_receiver() for e in l1_engines])
326+
engine = builder.build("l2 formula", nones_are_zeros)
327+
result_chan = engine.new_receiver()
328+
329+
now = datetime.now()
330+
tests_passed = 0
331+
for io_pair in io_pairs:
332+
io_input, io_output = io_pair
333+
assert all(
334+
await asyncio.gather(
335+
*[
336+
chan.new_sender().send(Sample(now, value))
337+
for chan, value in zip(channels, io_input)
338+
]
339+
)
340+
)
341+
next_val = await result_chan.receive()
342+
assert next_val.value == io_output
343+
tests_passed += 1
344+
await engine._stop() # pylint: disable=protected-access
345+
assert tests_passed == len(io_pairs)
346+
347+
async def test_simple(self) -> None:
348+
"""Test simple formulas."""
349+
await self.run_test(
350+
3,
351+
lambda c2, c4, c5: c2 - c4 + c5,
352+
[
353+
([10.0, 12.0, 15.0], 13.0),
354+
([15.0, 17.0, 20.0], 18.0),
355+
],
356+
)
357+
await self.run_test(
358+
3,
359+
lambda c2, c4, c5: c2 + c4 - c5,
360+
[
361+
([10.0, 12.0, 15.0], 7.0),
362+
([15.0, 17.0, 20.0], 12.0),
363+
],
364+
)
365+
await self.run_test(
366+
3,
367+
lambda c2, c4, c5: c2 * c4 + c5,
368+
[
369+
([10.0, 12.0, 15.0], 135.0),
370+
([15.0, 17.0, 20.0], 275.0),
371+
],
372+
)
373+
await self.run_test(
374+
3,
375+
lambda c2, c4, c5: c2 * c4 / c5,
376+
[
377+
([10.0, 12.0, 15.0], 8.0),
378+
([15.0, 17.0, 20.0], 12.75),
379+
],
380+
)
381+
await self.run_test(
382+
3,
383+
lambda c2, c4, c5: c2 / c4 - c5,
384+
[
385+
([6.0, 12.0, 15.0], -14.5),
386+
([15.0, 20.0, 20.0], -19.25),
387+
],
388+
)
389+
await self.run_test(
390+
3,
391+
lambda c2, c4, c5: c2 - c4 - c5,
392+
[
393+
([6.0, 12.0, 15.0], -21.0),
394+
([15.0, 20.0, 20.0], -25.0),
395+
],
396+
)
397+
await self.run_test(
398+
3,
399+
lambda c2, c4, c5: c2 + c4 + c5,
400+
[
401+
([6.0, 12.0, 15.0], 33.0),
402+
([15.0, 20.0, 20.0], 55.0),
403+
],
404+
)
405+
await self.run_test(
406+
3,
407+
lambda c2, c4, c5: c2 / c4 / c5,
408+
[
409+
([30.0, 3.0, 5.0], 2.0),
410+
([15.0, 3.0, 2.0], 2.5),
411+
],
412+
)
413+
414+
async def test_compound(self) -> None:
415+
"""Test compound formulas."""
416+
await self.run_test(
417+
4,
418+
lambda c2, c4, c5, c6: c2 + c4 - c5 * c6,
419+
[
420+
([10.0, 12.0, 15.0, 2.0], -8.0),
421+
([15.0, 17.0, 20.0, 1.5], 2.0),
422+
],
423+
)
424+
await self.run_test(
425+
4,
426+
lambda c2, c4, c5, c6: c2 + (c4 - c5) * c6,
427+
[
428+
([10.0, 12.0, 15.0, 2.0], 4.0),
429+
([15.0, 17.0, 20.0, 1.5], 10.5),
430+
],
431+
)
432+
await self.run_test(
433+
4,
434+
lambda c2, c4, c5, c6: c2 + (c4 - c5 * c6),
435+
[
436+
([10.0, 12.0, 15.0, 2.0], -8.0),
437+
([15.0, 17.0, 20.0, 1.5], 2.0),
438+
],
439+
)
440+
await self.run_test(
441+
4,
442+
lambda c2, c4, c5, c6: c2 + (c4 - c5 - c6),
443+
[
444+
([10.0, 12.0, 15.0, 2.0], 5.0),
445+
([15.0, 17.0, 20.0, 1.5], 10.5),
446+
],
447+
)
448+
await self.run_test(
449+
4,
450+
lambda c2, c4, c5, c6: c2 + c4 - c5 - c6,
451+
[
452+
([10.0, 12.0, 15.0, 2.0], 5.0),
453+
([15.0, 17.0, 20.0, 1.5], 10.5),
454+
],
455+
)
456+
await self.run_test(
457+
4,
458+
lambda c2, c4, c5, c6: c2 + c4 - (c5 - c6),
459+
[
460+
([10.0, 12.0, 15.0, 2.0], 9.0),
461+
([15.0, 17.0, 20.0, 1.5], 13.5),
462+
],
463+
)
464+
await self.run_test(
465+
4,
466+
lambda c2, c4, c5, c6: (c2 + c4 - c5) * c6,
467+
[
468+
([10.0, 12.0, 15.0, 2.0], 14.0),
469+
([15.0, 17.0, 20.0, 1.5], 18.0),
470+
],
471+
)
472+
await self.run_test(
473+
4,
474+
lambda c2, c4, c5, c6: (c2 + c4 - c5) / c6,
475+
[
476+
([10.0, 12.0, 15.0, 2.0], 3.5),
477+
([15.0, 17.0, 20.0, 1.5], 8.0),
478+
],
479+
)
480+
await self.run_test(
481+
4,
482+
lambda c2, c4, c5, c6: c2 + c4 - (c5 / c6),
483+
[
484+
([10.0, 12.0, 15.0, 2.0], 14.5),
485+
([15.0, 17.0, 20.0, 5.0], 28.0),
486+
],
487+
)
488+
489+
async def test_nones_are_zeros(self) -> None:
490+
"""Test that `None`s are treated as zeros when configured."""
491+
await self.run_test(
492+
3,
493+
lambda c2, c4, c5: c2 - c4 + c5,
494+
[
495+
([10.0, 12.0, 15.0], 13.0),
496+
([None, 12.0, 15.0], 3.0),
497+
([10.0, None, 15.0], 25.0),
498+
([15.0, 17.0, 20.0], 18.0),
499+
([15.0, None, None], 15.0),
500+
],
501+
True,
502+
)
503+
504+
await self.run_test(
505+
4,
506+
lambda c2, c4, c5, c6: c2 + c4 - (c5 * c6),
507+
[
508+
([10.0, 12.0, 15.0, 2.0], -8.0),
509+
([10.0, 12.0, 15.0, None], 22.0),
510+
([10.0, None, 15.0, 2.0], -20.0),
511+
([15.0, 17.0, 20.0, 5.0], -68.0),
512+
([15.0, 17.0, None, 5.0], 32.0),
513+
],
514+
True,
515+
)
516+
517+
async def test_nones_are_not_zeros(self) -> None:
518+
"""Test that calculated values are `None` on input `None`s."""
519+
await self.run_test(
520+
3,
521+
lambda c2, c4, c5: c2 - c4 + c5,
522+
[
523+
([10.0, 12.0, 15.0], 13.0),
524+
([None, 12.0, 15.0], None),
525+
([10.0, None, 15.0], None),
526+
([15.0, 17.0, 20.0], 18.0),
527+
([15.0, None, None], None),
528+
],
529+
False,
530+
)
531+
532+
await self.run_test(
533+
4,
534+
lambda c2, c4, c5, c6: c2 + c4 - (c5 * c6),
535+
[
536+
([10.0, 12.0, 15.0, 2.0], -8.0),
537+
([10.0, 12.0, 15.0, None], None),
538+
([10.0, None, 15.0, 2.0], None),
539+
([15.0, 17.0, 20.0, 5.0], -68.0),
540+
([15.0, 17.0, None, 5.0], None),
541+
],
542+
False,
543+
)

0 commit comments

Comments
 (0)