@@ -541,3 +541,67 @@ async def test_nones_are_not_zeros(self) -> None:
541541 ],
542542 False ,
543543 )
544+
545+
546+ class TestFormulaAverager :
547+ """Tests for the formula step for calculating average."""
548+
549+ async def run_test (
550+ self ,
551+ components : List [str ],
552+ io_pairs : List [Tuple [List [Optional [float ]], Optional [float ]]],
553+ ) -> None :
554+ """Run a formula test."""
555+ channels : Dict [str , Broadcast [Sample ]] = {}
556+ streams : List [Tuple [str , Receiver [Sample ], bool ]] = []
557+ builder = FormulaBuilder ("test_averager" )
558+ for comp_id in components :
559+ if comp_id not in channels :
560+ channels [comp_id ] = Broadcast (comp_id )
561+ streams .append ((f"{ comp_id } " , channels [comp_id ].new_receiver (), False ))
562+
563+ builder .push_average (streams )
564+ engine = builder .build ()
565+
566+ now = datetime .now ()
567+ tests_passed = 0
568+ for io_pair in io_pairs :
569+ io_input , io_output = io_pair
570+ assert all (
571+ await asyncio .gather (
572+ * [
573+ chan .new_sender ().send (Sample (now , value ))
574+ for chan , value in zip (channels .values (), io_input )
575+ ]
576+ )
577+ )
578+ next_val = await engine ._apply () # pylint: disable=protected-access
579+ assert (next_val ).value == io_output
580+ tests_passed += 1
581+ await engine ._stop () # pylint: disable=protected-access
582+ assert tests_passed == len (io_pairs )
583+
584+ async def test_simple (self ) -> None :
585+ """Test simple formulas."""
586+ await self .run_test (
587+ ["#2" , "#4" , "#5" ],
588+ [
589+ ([10.0 , 12.0 , 14.0 ], 12.0 ),
590+ ([15.0 , 17.0 , 19.0 ], 17.0 ),
591+ ([11.1 , 11.1 , 11.1 ], 11.1 ),
592+ ],
593+ )
594+
595+ async def test_nones_are_skipped (self ) -> None :
596+ """Test that `None`s are skipped for computing the average."""
597+ await self .run_test (
598+ ["#2" , "#4" , "#5" ],
599+ [
600+ ([11.0 , 13.0 , 15.0 ], 13.0 ),
601+ ([None , 13.0 , 19.0 ], 16.0 ),
602+ ([12.2 , None , 22.2 ], 17.2 ),
603+ ([16.5 , 19.5 , None ], 18.0 ),
604+ ([None , 13.0 , None ], 13.0 ),
605+ ([None , None , None ], 0.0 ),
606+ ],
607+ )
0 commit comments