Skip to content

Commit 939aef5

Browse files
dorellang authored and copybara-github committed
Add flag to parse query times from logs for serverless DPB services.
PiperOrigin-RevId: 719033039
1 parent a26965d commit 939aef5

File tree

7 files changed

+310
-77
lines changed

7 files changed

+310
-77
lines changed

perfkitbenchmarker/dpb_service.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
"""
2121

2222
import abc
23-
from collections.abc import MutableMapping
23+
from collections.abc import Callable, MutableMapping
2424
import dataclasses
2525
import datetime
2626
import logging
2727
import os
2828
import shutil
2929
import tempfile
30-
from typing import Dict, List, Type
30+
from typing import Dict, List, Type, TypeAlias
3131

3232
from absl import flags
3333
import jinja2
@@ -164,18 +164,37 @@ class JobSubmissionError(errors.Benchmarks.RunError):
164164
pass
165165

166166

167+
FetchOutputFn: TypeAlias = Callable[[], tuple[str | None, str | None]]
168+
169+
167170
@dataclasses.dataclass
168171
class JobResult:
169-
"""Data class for the timing of a successful DPB job."""
172+
"""Data class for the timing of a successful DPB job.
173+
174+
Attributes:
175+
run_time: Service reported execution time.
176+
pending_time: Service reported pending time (0 if service does not report).
177+
stdout: Job's stdout. Call FetchOutput before to ensure it's populated.
178+
stderr: Job's stderr. Call FetchOutput before to ensure it's populated.
179+
fetch_output_fn: Callback expected to return a 2-tuple of str or None whose
180+
values correspond to stdout and stderr respectively. This is called by
181+
FetchOutput which updates stdout and stderr if their respective value in
182+
this callback's return tuple is not None. Defaults to a no-op.
183+
"""
170184

171-
# Service reported execution time
172185
run_time: float
173-
# Service reported pending time (0 if service does not report).
174186
pending_time: float = 0
175-
# Stdout of the job.
176187
stdout: str = ''
177-
# Stderr of the job.
178188
stderr: str = ''
189+
fetch_output_fn: FetchOutputFn = lambda: (None, None)
190+
191+
def FetchOutput(self):
192+
"""Populates stdout and stderr according to fetch_output_fn callback."""
193+
stdout, stderr = self.fetch_output_fn()
194+
if stdout is not None:
195+
self.stdout = stdout
196+
if stderr is not None:
197+
self.stderr = stderr
179198

180199
@property
181200
def wall_time(self) -> float:
@@ -795,7 +814,8 @@ def CheckPrerequisites(self):
795814
if self.cloud == 'AWS' and not aws_flags.AWS_EC2_INSTANCE_PROFILE.value:
796815
raise ValueError(
797816
'EC2 based Spark and Hadoop services require '
798-
'--aws_ec2_instance_profile.')
817+
'--aws_ec2_instance_profile.'
818+
)
799819

800820
def GetClusterCreateTime(self) -> float | None:
801821
"""Returns the cluster creation time.

perfkitbenchmarker/linux_benchmarks/dpb_sparksql_benchmark.py

Lines changed: 124 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@
4848
import json
4949
import logging
5050
import os
51+
import re
5152
import time
52-
from typing import List
53+
from typing import Any, List
5354

5455
from absl import flags
5556
from perfkitbenchmarker import configs
@@ -60,6 +61,7 @@
6061
from perfkitbenchmarker import object_storage_service
6162
from perfkitbenchmarker import sample
6263
from perfkitbenchmarker import temp_dir
64+
from perfkitbenchmarker import vm_util
6365

6466
BENCHMARK_NAME = 'dpb_sparksql_benchmark'
6567

@@ -112,9 +114,34 @@
112114
'The record format to use when connecting to BigQuery storage. See: '
113115
'https://github.com/GoogleCloudDataproc/spark-bigquery-connector#properties',
114116
)
117+
_FETCH_RESULTS_FROM_LOGS = flags.DEFINE_bool(
118+
'dpb_sparksql_fetch_results_from_logs',
119+
False,
120+
'Make the query runner script to log query timings to stdout/stderr '
121+
' instead of writing them to some object storage location. Reduces runner '
122+
' latency (and hence its total wall time), but it is not supported by all '
123+
' DPB services.',
124+
)
115125

116126
FLAGS = flags.FLAGS
117127

128+
LOG_RESULTS_PATTERN = (
129+
r'----@spark_sql_runner:results_start@----'
130+
r'(.*)'
131+
r'----@spark_sql_runner:results_end@----'
132+
)
133+
POLL_LOGS_INTERVAL = 60
134+
POLL_LOGS_TIMEOUT = 6 * 60
135+
RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES = (
136+
dpb_constants.DATAPROC_SERVERLESS,
137+
dpb_constants.EMR_SERVERLESS,
138+
dpb_constants.GLUE,
139+
)
140+
141+
142+
class QueryResultsNotReadyError(Exception):
143+
"""Used to signal a job is still running."""
144+
118145

119146
def GetConfig(user_config):
120147
return configs.LoadConfig(BENCHMARK_CONFIG, user_config, BENCHMARK_NAME)
@@ -129,7 +156,6 @@ def CheckPrerequisites(benchmark_config):
129156
Raises:
130157
Config.InvalidValue: On encountering invalid configuration.
131158
"""
132-
del benchmark_config # unused
133159
if not FLAGS.dpb_sparksql_data and FLAGS.dpb_sparksql_create_hive_tables:
134160
raise errors.Config.InvalidValue(
135161
'You must pass dpb_sparksql_data with dpb_sparksql_create_hive_tables'
@@ -160,7 +186,7 @@ def CheckPrerequisites(benchmark_config):
160186
bool(_BIGQUERY_TABLES.value),
161187
bool(FLAGS.dpb_sparksql_database),
162188
])
163-
== 1
189+
> 1
164190
):
165191
logging.warning(
166192
'You should only pass one of them: --dpb_sparksql_data,'
@@ -176,6 +202,16 @@ def CheckPrerequisites(benchmark_config):
176202
'--dpb_sparksql_simultaneous is not compatible with '
177203
'--dpb_sparksql_streams.'
178204
)
205+
if (
206+
_FETCH_RESULTS_FROM_LOGS.value
207+
and benchmark_config.dpb_service.service_type
208+
not in RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES
209+
):
210+
raise errors.Config.InvalidValue(
211+
f'Current dpb service {benchmark_config.dpb_service.service_type!r} is'
212+
' not supported for --dpb_sparksql_fetch_results_from_logs. Supported'
213+
f' dpb services are: {RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES!r}'
214+
)
179215

180216

181217
def Prepare(benchmark_spec):
@@ -275,7 +311,7 @@ def Run(benchmark_spec):
275311
# Run PySpark Spark SQL Runner
276312
report_dir, job_result = _RunQueries(benchmark_spec)
277313

278-
results = _GetQuerySamples(storage_service, report_dir, metadata)
314+
results = _GetQuerySamples(storage_service, report_dir, job_result, metadata)
279315
results += _GetGlobalSamples(results, cluster, job_result, metadata)
280316
results += _GetPrepareSamples(benchmark_spec, metadata)
281317
return results
@@ -319,7 +355,10 @@ def _RunQueries(benchmark_spec) -> tuple[str, dpb_service.JobResult]:
319355
else:
320356
for stream in benchmark_spec.query_streams:
321357
args += ['--sql-queries', ','.join(stream)]
322-
args += ['--report-dir', report_dir]
358+
if _FETCH_RESULTS_FROM_LOGS.value:
359+
args += ['--log-results', 'True']
360+
else:
361+
args += ['--report-dir', report_dir]
323362
if FLAGS.dpb_sparksql_database:
324363
args += ['--database', FLAGS.dpb_sparksql_database]
325364
if FLAGS.dpb_sparksql_create_hive_tables:
@@ -365,46 +404,33 @@ def _RunQueries(benchmark_spec) -> tuple[str, dpb_service.JobResult]:
365404
def _GetQuerySamples(
366405
storage_service: object_storage_service.ObjectStorageService,
367406
report_dir: str,
407+
job_result: dpb_service.JobResult,
368408
base_metadata: MutableMapping[str, str],
369409
) -> list[sample.Sample]:
370-
"""Get Sample objects from metrics storage path."""
371-
# Spark can only write data to directories not files. So do a recursive copy
372-
# of that directory and then search it for the collection of JSON files with
373-
# the results.
374-
temp_run_dir = temp_dir.GetRunDirPath()
375-
storage_service.Copy(report_dir, temp_run_dir, recursive=True)
376-
report_files = []
377-
for dir_name, _, files in os.walk(
378-
os.path.join(temp_run_dir, os.path.basename(report_dir))
379-
):
380-
for filename in files:
381-
if filename.endswith('.json'):
382-
report_file = os.path.join(dir_name, filename)
383-
report_files.append(report_file)
384-
logging.info("Found report file '%s'.", report_file)
385-
if not report_files:
386-
raise errors.Benchmarks.RunError('Job report not found.')
410+
"""Get Sample objects from job's logs."""
411+
412+
if _FETCH_RESULTS_FROM_LOGS.value:
413+
query_results = _FetchResultsFromLogs(job_result)
414+
else:
415+
query_results = _FetchResultsFromStorage(storage_service, report_dir)
387416

388417
samples = []
389-
for report_file in report_files:
390-
with open(report_file) as file:
391-
for line in file:
392-
result = json.loads(line)
393-
logging.info('Timing: %s', result)
394-
query_id = result['query_id']
395-
assert query_id
396-
metadata_copy = base_metadata.copy()
397-
metadata_copy['query'] = query_id
398-
if FLAGS.dpb_sparksql_streams:
399-
metadata_copy['stream'] = result['stream']
400-
samples.append(
401-
sample.Sample(
402-
'sparksql_run_time',
403-
result['duration'],
404-
'seconds',
405-
metadata_copy,
406-
)
418+
for result in query_results:
419+
logging.info('Timing: %s', result)
420+
query_id = result['query_id']
421+
assert query_id
422+
metadata_copy = dict(base_metadata)
423+
metadata_copy['query'] = query_id
424+
if FLAGS.dpb_sparksql_streams:
425+
metadata_copy['stream'] = result['stream']
426+
samples.append(
427+
sample.Sample(
428+
'sparksql_run_time',
429+
result['duration'],
430+
'seconds',
431+
metadata_copy,
407432
)
433+
)
408434
return samples
409435

410436

@@ -524,6 +550,64 @@ def _GetPrepareSamples(
524550
return samples
525551

526552

553+
def _FetchResultsFromStorage(
554+
storage_service: object_storage_service.ObjectStorageService,
555+
report_dir: str,
556+
) -> list[dict[str, Any]]:
557+
"""Get Sample objects from metrics storage path."""
558+
# Spark can only write data to directories not files. So do a recursive copy
559+
# of that directory and then search it for the collection of JSON files with
560+
# the results.
561+
temp_run_dir = temp_dir.GetRunDirPath()
562+
storage_service.Copy(report_dir, temp_run_dir, recursive=True)
563+
report_files = []
564+
for dir_name, _, files in os.walk(
565+
os.path.join(temp_run_dir, os.path.basename(report_dir))
566+
):
567+
for filename in files:
568+
if filename.endswith('.json'):
569+
report_file = os.path.join(dir_name, filename)
570+
report_files.append(report_file)
571+
logging.info("Found report file '%s'.", report_file)
572+
if not report_files:
573+
raise errors.Benchmarks.RunError('Job report not found.')
574+
results = []
575+
for report_file in report_files:
576+
with open(report_file) as file:
577+
for line in file:
578+
results.append(json.loads(line))
579+
return results
580+
581+
582+
@vm_util.Retry(
583+
timeout=POLL_LOGS_TIMEOUT,
584+
poll_interval=POLL_LOGS_INTERVAL,
585+
fuzz=0,
586+
retryable_exceptions=(QueryResultsNotReadyError,),
587+
)
588+
def _FetchResultsFromLogs(job_result: dpb_service.JobResult):
589+
"""Get samples from job results logs."""
590+
job_result.FetchOutput()
591+
logs = '\n'.join([job_result.stdout or '', job_result.stderr])
592+
query_results = _ParseResultsFromLogs(logs)
593+
if query_results is None:
594+
raise QueryResultsNotReadyError
595+
return query_results
596+
597+
598+
def _ParseResultsFromLogs(logs: str) -> list[dict[str, Any]] | None:
599+
json_str_match = re.search(LOG_RESULTS_PATTERN, logs, re.DOTALL)
600+
if not json_str_match:
601+
return None
602+
try:
603+
results = json.loads(json_str_match.group(1))
604+
except ValueError as e:
605+
raise errors.Benchmarks.RunError(
606+
'Corrupted results from logs cannot be deserialized.'
607+
) from e
608+
return results
609+
610+
527611
def _GetDistCpMetadata(base_dir: str, subdirs: List[str], extra_metadata=None):
528612
"""Compute list of table metadata for spark_sql_distcp metadata flags."""
529613
metadata = []

0 commit comments

Comments (0)