11import json
2- import os
3- import warnings
42from dataclasses import dataclass
5- from functools import (
6- cache ,
7- wraps ,
8- )
93from typing import Any
104
115import pytest
126
137
@cache
def _get_benchmark_data() -> dict[tuple[str, str], dict[str, Any]]:
    """Load benchmark stats from the report named by BENCHMARK_VALIDATE_FILE.

    Returns a mapping from ``(bare test name, key-sorted JSON of params)`` to
    that benchmark's raw ``stats`` dict.  Cached, so the report file is read
    at most once per process.

    Raises:
        RuntimeError: if the BENCHMARK_VALIDATE_FILE environment variable
            is not set.
    """
    path = os.getenv("BENCHMARK_VALIDATE_FILE")
    if path is None:
        raise RuntimeError("Environment variable BENCHMARK_VALIDATE_FILE is not set.")

    with open(path) as fh:
        report = json.load(fh)

    table: dict[tuple[str, str], dict[str, Any]] = {}
    for entry in report["benchmarks"]:
        # "name[param-id]" -> "name": strip pytest's parametrization suffix.
        bare_name = entry["name"].split("[")[0]
        # Key-sorted JSON gives a stable, hashable key for the params dict.
        params_key = json.dumps(entry["params"], sort_keys=True)
        table[bare_name, params_key] = entry["stats"]
    return table
29-
30-
318@dataclass
329class PerformanceTestCaseSpec :
3310 fn_name : str
@@ -44,6 +21,11 @@ def get_params_for_parametrize(self):
4421 def get_params_json (self ):
4522 return json .dumps (self .params , sort_keys = True )
4623
24+ def get_params_human (self ):
25+ if all (type (value ) in [float , int ] for value in self .params .values ()):
26+ return ", " .join (f"{ key } ={ value } " for key , value in sorted (self .params .items ()))
27+ return self .get_params_json ()
28+
4729
4830def expected_benchmark (* multiple_cases : dict , ** single_case : dict ):
4931 def wrapper (fn ):
@@ -59,7 +41,7 @@ def wrapper(fn):
5941
6042 if case_param_keys != param_keys :
6143 raise ValueError (
62- "All expected_benchmark decorators must have the same parameter keys."
44+ "All listed cases in expected_benchmark must have the same parameter keys."
6345 f"Expected { param_keys } , got { case_param_keys } "
6446 )
6547
@@ -73,93 +55,11 @@ def wrapper(fn):
7355 )
7456 )
7557
76- if not os .getenv ("BENCHMARK_VALIDATE_FILE" ):
77- pytest .mark .parametrize (
78- "," .join (param_keys ),
79- [spec .get_params_for_parametrize () for spec in specs ],
80- )(fn )
81- return fn
82-
83- performance_factor = float (os .getenv ("BENCHMARK_PERFORMANCE_FACTOR" , "1.0" ))
84-
85- @wraps (fn )
86- def validation (* args , ** kwargs ):
87- # Find the matching spec
88- spec : PerformanceTestCaseSpec | None = None
89- for case in specs :
90- if all (kwargs .get (k ) == v for k , v in case .params .items ()):
91- spec = case
92- break
93-
94- assert spec is not None , "No matching performance case found for the given parameters."
95-
96- # Extract the actual parameters used in this test run
97- if spec .min_p0 is None or spec .max_p80 is None or spec .max_p100 is None :
98- warnings .warn ("Benchmark thresholds not set, skipping validation." , category = UserWarning )
99- return
100-
101- perf_data = _get_benchmark_data ()
102-
103- assert spec .fn_name , spec .get_params_json () in perf_data
104- stats = perf_data [spec .fn_name , spec .get_params_json ()]
105-
106- times = sorted (stats ["data" ])
107- p0 = times [0 ]
108- p80 = times [int (len (times ) * 0.8 )]
109- p100 = times [- 1 ]
110-
111- adjusted_min_p0 = spec .min_p0 * performance_factor
112- adjusted_max_p80 = spec .max_p80 * performance_factor
113- adjusted_max_p100 = spec .max_p100 * performance_factor
114-
115- p0_marker = "✓" if p0 >= adjusted_min_p0 else "✗"
116- p80_marker = "✓" if p80 <= adjusted_max_p80 else "✗"
117- p100_marker = "✓" if p100 <= adjusted_max_p100 else "✗"
118-
119- params_human = ", " .join (f"{ k } ={ v !r} " for k , v in spec .params .items ())
120- detailed_msg = f"""
121-
122- Benchmark '{ spec .fn_name } ' with params { params_human } results:
123-
124- { p0_marker } 0th percentile: { p0 :.3f} s
125- Unadjusted min_p0: { spec .min_p0 :.3f} s
126- Adjusted (*) min_p0: { adjusted_min_p0 :.3f} s
127-
128- { p80_marker } 80th percentile: { p80 :.3f} s
129- Unadjusted max_p80: { spec .max_p80 :.3f} s
130- Adjusted (*) max_p80: { adjusted_max_p80 :.3f} s
131-
132- { p100_marker } 100th percentile: { p100 :.3f} s
133- Unadjusted max_p100: { spec .max_p100 :.3f} s
134- Adjusted (*) max_p100: { adjusted_max_p100 :.3f} s
135-
136- (*) Use the environment variable "BENCHMARK_PERFORMANCE_FACTOR" to adjust the thresholds.
137-
138- BENCHMARK_PERFORMANCE_FACTOR=1.0 (default) is meant to represent GitHub Actions performance.
139- Decrease this factor if your local machine is faster than GitHub Actions.
140-
141- """
142-
143- if performance_factor == 1.0 :
144- adjusted_min_p0_str = f"{ adjusted_min_p0 :.3f} "
145- adjusted_max_p80_str = f"{ adjusted_max_p80 :.3f} "
146- adjusted_max_p100_str = f"{ adjusted_max_p100 :.3f} "
147- else :
148- adjusted_min_p0_str = f"{ adjusted_min_p0 :.3f} (= { spec .min_p0 :.3f} * { performance_factor } )"
149- adjusted_max_p80_str = f"{ adjusted_max_p80 :.3f} (= { spec .max_p80 :.3f} * { performance_factor } )"
150- adjusted_max_p100_str = f"{ adjusted_max_p100 :.3f} (= { spec .max_p100 :.3f} * { performance_factor } )"
151-
152- assert p0 >= adjusted_min_p0 , f"p0 { p0 :.3f} is less than expected { adjusted_min_p0_str } " + detailed_msg
153- assert p80 <= adjusted_max_p80 , f"p80 { p80 :.3f} is more than expected { adjusted_max_p80_str } " + detailed_msg
154- assert p100 <= adjusted_max_p100 , (
155- f"p100 { p100 :.3f} is more than expected { adjusted_max_p100_str } " + detailed_msg
156- )
157-
15858 pytest .mark .parametrize (
15959 "," .join (param_keys ),
16060 [spec .get_params_for_parametrize () for spec in specs ],
161- )(validation )
162-
163- return validation
61+ )(fn )
62+ fn . __expected_benchmark_specs = specs
63+ return fn
16464
16565 return wrapper
0 commit comments