Commit ffb406a

chore: benchmark refactors and updates (#1043)
* chore: delete failed benchmarks and collect local execution time for notebook
* simplify reading
* update error handling
* update geometric mean
1 parent: 1bfa598 · commit: ffb406a

2 files changed (+98, -74 lines)

noxfile.py

Lines changed: 4 additions & 27 deletions
@@ -23,7 +23,6 @@
 import re
 import shutil
 import time
-import traceback
 from typing import Dict, List
 import warnings
 
@@ -794,10 +793,6 @@ def notebook(session: nox.Session):
             *notebooks,
         )
 
-        # Shared flag using multiprocessing.Manager() to indicate if
-        # any process encounters an error. This flag may be updated
-        # across different processes.
-        error_flag = multiprocessing.Manager().Value("i", False)
         processes = []
         for notebook in notebooks:
             args = (
@@ -808,8 +803,8 @@ def notebook(session: nox.Session):
             )
             if multi_process_mode:
                 process = multiprocessing.Process(
-                    target=_run_process,
-                    args=(session, args, error_flag),
+                    target=session.run,
+                    args=args,
                 )
                 process.start()
                 processes.append(process)
@@ -819,10 +814,6 @@ def notebook(session: nox.Session):
             else:
                 session.run(*args)
 
-        for process in processes:
-            process.join()
-
-        processes = []
         for notebook, regions in notebooks_reg.items():
             for region in regions:
                 region_args = (
@@ -834,8 +825,8 @@ def notebook(session: nox.Session):
                 )
                 if multi_process_mode:
                     process = multiprocessing.Process(
-                        target=_run_process,
-                        args=(session, region_args, error_flag),
+                        target=session.run,
+                        args=region_args,
                     )
                     process.start()
                     processes.append(process)
@@ -847,11 +838,6 @@ def notebook(session: nox.Session):
 
         for process in processes:
             process.join()
-
-        # Check the shared error flag and raise an exception if any process
-        # reported an error
-        if error_flag.value:
-            raise Exception("Errors occurred in one or more subprocesses.")
     finally:
         # Prevent our notebook changes from getting checked in to git
         # accidentally.
@@ -868,15 +854,6 @@ def notebook(session: nox.Session):
         )
 
 
-def _run_process(session: nox.Session, args, error_flag):
-    try:
-        session.run(*args)
-    except Exception:
-        traceback_str = traceback.format_exc()
-        print(traceback_str)
-        error_flag.value = True
-
-
 @nox.session(python=DEFAULT_PYTHON_VERSION)
 def benchmark(session: nox.Session):
     session.install("-e", ".[all]")
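Note on the noxfile change: the `_run_process` wrapper and the shared `error_flag` existed because an exception raised inside a `multiprocessing.Process` never propagates to the parent process; the flag was the only way for the parent to learn that a child had failed. With this commit, `session.run` is passed as the target directly, and failure signaling moves into `scripts/run_and_publish_benchmark.py`, which records failures as `.error` sentinel files (see the second file below). A minimal sketch, not from the commit, of the behavior that motivated the original flag:

    import multiprocessing

    def fail():
        # The traceback prints in the child; the parent sees only a
        # nonzero exitcode, never the exception itself.
        raise ValueError("boom")

    if __name__ == "__main__":
        p = multiprocessing.Process(target=fail)
        p.start()
        p.join()  # returns normally; nothing is re-raised here
        print(p.exitcode)  # 1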

scripts/run_and_publish_benchmark.py

Lines changed: 94 additions & 47 deletions
@@ -17,10 +17,11 @@
 import json
 import os
 import pathlib
+import re
 import subprocess
 import sys
 import tempfile
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -30,7 +31,7 @@
 CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
 
 
-def run_benchmark_subprocess(args, log_env_name_var, filename=None, region=None):
+def run_benchmark_subprocess(args, log_env_name_var, file_path=None, region=None):
     """
     Runs a benchmark subprocess with configured environment variables. Adjusts PYTHONPATH,
     sets region-specific BigQuery location, and logs environment variables.
@@ -48,36 +49,56 @@ def run_benchmark_subprocess(args, log_env_name_var, file_path=None, region=None)
     if region:
         env["BIGQUERY_LOCATION"] = region
     env[LOGGING_NAME_ENV_VAR] = log_env_name_var
-    subprocess.run(args, env=env, check=True)
-
-
-def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFrame:
+    try:
+        if file_path:  # Notebooks
+            duration_pattern = re.compile(r"(\d+\.\d+)s call")
+            process = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, text=True)
+            assert process.stdout is not None
+            for line in process.stdout:
+                print(line, end="")
+                match = duration_pattern.search(line)
+                if match:
+                    duration = match.group(1)
+                    with open(f"{file_path}.local_exec_time_seconds", "w") as f:
+                        f.write(f"{duration}\n")
+            process.wait()
+            if process.returncode != 0:
+                raise subprocess.CalledProcessError(process.returncode, args)
+        else:  # Benchmarks
+            file_path = log_env_name_var
+            subprocess.run(args, env=env, check=True)
+    except Exception:
+        directory = pathlib.Path(file_path).parent
+        for file in directory.glob(f"{pathlib.Path(file_path).name}.*"):
+            if file.suffix != ".backup":
+                print(f"Benchmark failed, deleting: {file}")
+                file.unlink()
+        error_file = directory / f"{pathlib.Path(file_path).name}.error"
+        error_file.touch()
+
+
+def collect_benchmark_result(
+    benchmark_path: str, iterations: int
+) -> Tuple[pd.DataFrame, Union[str, None]]:
     """Generate a DataFrame report on HTTP queries, bytes processed, slot time and execution time from log files."""
     path = pathlib.Path(benchmark_path)
     try:
         results_dict: Dict[str, List[Union[int, float, None]]] = {}
         bytes_files = sorted(path.rglob("*.bytesprocessed"))
         millis_files = sorted(path.rglob("*.slotmillis"))
         bq_seconds_files = sorted(path.rglob("*.bq_exec_time_seconds"))
-
         local_seconds_files = sorted(path.rglob("*.local_exec_time_seconds"))
-        has_local_seconds = len(local_seconds_files) > 0
-
-        if has_local_seconds:
-            if not (
-                len(bytes_files)
-                == len(millis_files)
-                == len(local_seconds_files)
-                == len(bq_seconds_files)
-            ):
-                raise ValueError(
-                    "Mismatch in the number of report files for bytes, millis, and seconds."
-                )
-        else:
-            if not (len(bytes_files) == len(millis_files) == len(bq_seconds_files)):
-                raise ValueError(
-                    "Mismatch in the number of report files for bytes, millis, and seconds."
-                )
+        error_files = sorted(path.rglob("*.error"))
+
+        if not (
+            len(bytes_files)
+            == len(millis_files)
+            == len(local_seconds_files)
+            == len(bq_seconds_files)
+        ):
+            raise ValueError(
+                "Mismatch in the number of report files for bytes, millis, and seconds."
+            )
 
         for idx in range(len(bytes_files)):
             bytes_file = bytes_files[idx]
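The notebook branch above streams the child's stdout line by line and scrapes the per-test timing that pytest prints when given `--durations=0` (added to the pytest command later in this diff). A minimal sketch of what the regex matches, using a made-up line in the style of pytest's durations report:

    import re

    # Hypothetical sample line; the notebook path and cell id are illustrative.
    sample = "12.34s call     notebooks/example_notebook.ipynb::Cell 0"
    duration_pattern = re.compile(r"(\d+\.\d+)s call")
    match = duration_pattern.search(sample)
    assert match is not None
    print(match.group(1))  # -> "12.34"

Note that the `.local_exec_time_seconds` file is opened in "w" mode on every match, so if several lines match, only the last duration is kept.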
@@ -92,12 +113,11 @@ def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFra
                     "File name mismatch among bytes, millis, and seconds reports."
                 )
 
-            if has_local_seconds:
-                local_seconds_file = local_seconds_files[idx]
-                if filename != local_seconds_file.relative_to(path).with_suffix(""):
-                    raise ValueError(
-                        "File name mismatch among bytes, millis, and seconds reports."
-                    )
+            local_seconds_file = local_seconds_files[idx]
+            if filename != local_seconds_file.relative_to(path).with_suffix(""):
+                raise ValueError(
+                    "File name mismatch among bytes, millis, and seconds reports."
+                )
 
             with open(bytes_file, "r") as file:
                 lines = file.read().splitlines()
@@ -108,12 +128,9 @@ def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFra
                 lines = file.read().splitlines()
             total_slot_millis = sum(int(line) for line in lines) / iterations
 
-            if has_local_seconds:
-                with open(local_seconds_file, "r") as file:
-                    lines = file.read().splitlines()
-                local_seconds = sum(float(line) for line in lines) / iterations
-            else:
-                local_seconds = None
+            with open(local_seconds_file, "r") as file:
+                lines = file.read().splitlines()
+            local_seconds = sum(float(line) for line in lines) / iterations
 
             with open(bq_seconds_file, "r") as file:
                 lines = file.read().splitlines()
@@ -132,6 +149,7 @@ def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFra
             path.rglob("*.slotmillis"),
             path.rglob("*.local_exec_time_seconds"),
             path.rglob("*.bq_exec_time_seconds"),
+            path.rglob("*.error"),
         ):
             for log_file in files_to_remove:
                 log_file.unlink()
@@ -170,13 +188,19 @@ def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFra
             f" bigquery execution time: {round(row['BigQuery_Execution_Time_Sec'], 1)} seconds"
         )
 
-    geometric_mean_queries = geometric_mean(benchmark_metrics["Query_Count"])
-    geometric_mean_bytes = geometric_mean(benchmark_metrics["Bytes_Processed"])
-    geometric_mean_slot_millis = geometric_mean(benchmark_metrics["Slot_Millis"])
-    geometric_mean_local_seconds = geometric_mean(
+    geometric_mean_queries = geometric_mean_excluding_zeros(
+        benchmark_metrics["Query_Count"]
+    )
+    geometric_mean_bytes = geometric_mean_excluding_zeros(
+        benchmark_metrics["Bytes_Processed"]
+    )
+    geometric_mean_slot_millis = geometric_mean_excluding_zeros(
+        benchmark_metrics["Slot_Millis"]
+    )
+    geometric_mean_local_seconds = geometric_mean_excluding_zeros(
         benchmark_metrics["Local_Execution_Time_Sec"]
     )
-    geometric_mean_bq_seconds = geometric_mean(
+    geometric_mean_bq_seconds = geometric_mean_excluding_zeros(
         benchmark_metrics["BigQuery_Execution_Time_Sec"]
     )
 
@@ -188,15 +212,33 @@ def collect_benchmark_result(benchmark_path: str, iterations: int) -> pd.DataFra
         f"Geometric mean of BigQuery execution time: {geometric_mean_bq_seconds} seconds---"
     )
 
-    return benchmark_metrics.reset_index().rename(columns={"index": "Benchmark_Name"})
+    error_message = (
+        "\n"
+        + "\n".join(
+            [
+                f"Failed: {error_file.relative_to(path).with_suffix('')}"
+                for error_file in error_files
+            ]
+        )
+        if error_files
+        else None
+    )
+    return (
+        benchmark_metrics.reset_index().rename(columns={"index": "Benchmark_Name"}),
+        error_message,
+    )
 
 
-def geometric_mean(data):
+def geometric_mean_excluding_zeros(data):
     """
-    Calculate the geometric mean of a dataset, rounding the result to one decimal place.
-    Returns NaN if the dataset is empty or contains only NaN values.
+    Calculate the geometric mean of a dataset, excluding any zero values.
+    Returns NaN if the dataset is empty, contains only NaN values, or if
+    all non-NaN values are zeros.
+
+    The result is rounded to one decimal place.
     """
     data = data.dropna()
+    data = data[data != 0]
     if len(data) == 0:
         return np.nan
     log_data = np.log(data)
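The rename from `geometric_mean` to `geometric_mean_excluding_zeros` reflects a numerical hazard: the geometric mean is computed as exp(mean(log x)), and a single zero contributes log(0) = -inf, which collapses the whole mean to zero. Dropping zeros (like dropping NaNs) keeps one degenerate sample from wiping out the aggregate. A sketch of the full helper under that reading; the final `return` is not shown in the hunk and is assumed here to round to one decimal place, as the docstring states:

    import numpy as np
    import pandas as pd

    def geometric_mean_excluding_zeros(data: pd.Series) -> float:
        # Drop NaNs and zeros, then compute exp(mean(log(x))).
        data = data.dropna()
        data = data[data != 0]
        if len(data) == 0:
            return np.nan
        log_data = np.log(data)
        return round(float(np.exp(log_data.mean())), 1)

    print(geometric_mean_excluding_zeros(pd.Series([0.0, 2.0, 8.0])))  # 4.0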
@@ -321,13 +363,15 @@ def run_notebook_benchmark(benchmark_file: str, region: str):
         "py.test",
         "--nbmake",
         "--nbmake-timeout=900",  # 15 minutes
+        "--durations=0",
+        "--color=yes",
     ]
     benchmark_args = (*pytest_command, benchmark_file)
 
     run_benchmark_subprocess(
         args=benchmark_args,
         log_env_name_var=log_env_name_var,
-        filename=export_file,
+        file_path=export_file,
         region=region,
     )
 
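Two notes on the new pytest flags: `--durations=0` makes pytest report the duration of every test phase rather than only the slowest ones, which is what produces the `<N>s call ...` lines parsed by `run_benchmark_subprocess`; and because stdout is now captured through a pipe rather than a TTY, `--color=yes` is presumably there to keep pytest from disabling its colored output.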
@@ -383,7 +427,7 @@ def main():
     args = parse_arguments()
 
     if args.publish_benchmarks:
-        benchmark_metrics = collect_benchmark_result(
+        benchmark_metrics, error_message = collect_benchmark_result(
            args.publish_benchmarks, args.iterations
        )
        # Output results to CSV without specifying a location
@@ -412,6 +456,9 @@
         # intended for local testing where the default behavior is not to publish results.
         elif project := os.getenv("GCLOUD_BENCH_PUBLISH_PROJECT", ""):
             publish_to_bigquery(benchmark_metrics, args.notebook, project)
+
+        if error_message:
+            raise Exception(error_message)
     elif args.notebook:
         run_notebook_benchmark(args.benchmark_path, args.region)
     else:
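Taken together, error handling now works as a sentinel-file protocol instead of in-band exceptions: a failed run deletes its partial metric files and touches `<name>.error`; `collect_benchmark_result` globs those sentinels into an `error_message`; and `main` raises only after the metrics have been written out and optionally published. A condensed sketch of the collection side, assuming the same on-disk layout:

    import pathlib

    path = pathlib.Path("benchmark_logs")  # hypothetical output root
    error_files = sorted(path.rglob("*.error"))
    if error_files:
        raise Exception(
            "\n"
            + "\n".join(
                f"Failed: {f.relative_to(path).with_suffix('')}" for f in error_files
            )
        )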
