diff --git a/src/perfplot/_main.py b/src/perfplot/_main.py index 5e10e93..cfcfb21 100644 --- a/src/perfplot/_main.py +++ b/src/perfplot/_main.py @@ -258,7 +258,7 @@ def __next__(self): raise RuntimeError("Measured 0 ns for a function call. Try again?") if self.equality_check: - if k == 0: + if reference is None: reference = val else: try: @@ -519,7 +519,7 @@ def callback(): for i in range(len(n_range)): timings_s[i] = next(b) - # override n_rane in case it got overridden in next() + # override n_range in case it got overridden in next() n_range = b.n_range if show_progress: @@ -539,17 +539,17 @@ def callback(): def plot( *args, time_unit: str = "s", + relative_to: int | None = None, logx: Literal["auto"] | bool = "auto", logy: Literal["auto"] | bool = "auto", - relative_to: int | None = None, **kwargs, ): out = bench(*args, **kwargs) out.plot( time_unit=time_unit, + relative_to=relative_to, logx=logx, logy=logy, - relative_to=relative_to, ) @@ -575,9 +575,9 @@ def save( transparent=True, *args, time_unit: str = "s", + relative_to: int | None = None, logx: bool | Literal["auto"] = "auto", logy: bool | Literal["auto"] = "auto", - relative_to: int | None = None, **kwargs, ): out = bench(*args, **kwargs) @@ -585,7 +585,7 @@ def save( filename, transparent, time_unit=time_unit, + relative_to=relative_to, logx=logx, logy=logy, - relative_to=relative_to, ) diff --git a/tests/test_perfplot.py b/tests/test_perfplot.py index b35f8ca..7fe004b 100644 --- a/tests/test_perfplot.py +++ b/tests/test_perfplot.py @@ -150,3 +150,24 @@ def times_reversed(a, b): perfplot.show( setup=setup, kernels=[times, times_reversed], n_range=[2**k for k in range(3)] ) + + +def test_exceed_time(): + """other functions should be checked for validity even if the first timed out""" + import time + def setup(n): + return np.random.rand(n) + + def exceed_time(a): + time.sleep(0.5) + return 1 + + def in_time(a): + return 1 + + def in_time2(a): + return 1 + + b = perfplot.bench( + setup=setup, kernels=[exceed_time, in_time, in_time2], n_range=[2**k for k in range(3)], max_time=0.1 + )