Skip to content

Commit 0096386

Browse files
dorellang and copybara-github
authored and committed
Bundle queries in the script runner instead of loading them with Spark API.
PiperOrigin-RevId: 715850728
1 parent 1dca42d commit 0096386

File tree

4 files changed

+170
-112
lines changed

4 files changed

+170
-112
lines changed

perfkitbenchmarker/dpb_sparksql_benchmark_helper.py

Lines changed: 89 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import json
1717
import os
1818
import re
19-
import typing
2019
from typing import Any
2120

2221
from absl import flags
@@ -130,15 +129,20 @@
130129
FLAGS = flags.FLAGS
131130

132131
SCRIPT_DIR = 'spark_sql_test_scripts'
132+
QUERIES_PY_BASENAME = 'spark_sql_queries.py'
133133
SPARK_SQL_DISTCP_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_sql_distcp.py')
134134
# Creates spark table using pyspark by loading the parquet data.
135135
# Args:
136136
# argv[1]: string, The table name in the dataset that this script will create.
137137
# argv[2]: string, The data path of the table.
138138
SPARK_TABLE_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_table.py')
139139
SPARK_SQL_RUNNER_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_sql_runner.py')
140+
QUERIES_SCRIPT = os.path.join(SCRIPT_DIR, QUERIES_PY_BASENAME)
140141
SPARK_SQL_PERF_GIT = 'https://github.com/databricks/spark-sql-perf.git'
141142
SPARK_SQL_PERF_GIT_COMMIT = '6b2bf9f9ad6f6c2f620062fda78cded203f619c8'
143+
QUERIES_SUB_PATTERN = (
144+
r'^#[\s]*spark_sql_queries:start[\s]*$.*^#[\s]*spark_sql_queries:end[\s]*$'
145+
)
142146

143147

144148
def GetStreams() -> list[list[str]]:
@@ -183,19 +187,20 @@ def Prepare(benchmark_spec):
183187
"""Copies scripts and all the queries to cloud."""
184188
cluster = benchmark_spec.dpb_service
185189
storage_service = cluster.storage_service
186-
benchmark_spec.query_dir = LoadAndStageQueries(
187-
storage_service, cluster.base_dir
188-
)
190+
queries = _FetchQueryContents(storage_service)
191+
rendered_runner_filepath = _RenderRunnerScriptWithQueries(queries)
189192
benchmark_spec.query_streams = GetStreams()
190193

191-
scripts_to_upload = [
192-
SPARK_SQL_DISTCP_SCRIPT,
193-
SPARK_TABLE_SCRIPT,
194-
SPARK_SQL_RUNNER_SCRIPT,
195-
] + cluster.GetServiceWrapperScriptsToUpload()
196-
for script in scripts_to_upload:
197-
src_url = data.ResourcePath(script)
198-
storage_service.CopyToBucket(src_url, cluster.bucket, script)
194+
scripts_to_upload = {
195+
data.ResourcePath(SPARK_SQL_DISTCP_SCRIPT): SPARK_SQL_DISTCP_SCRIPT,
196+
data.ResourcePath(SPARK_TABLE_SCRIPT): SPARK_TABLE_SCRIPT,
197+
rendered_runner_filepath: SPARK_SQL_RUNNER_SCRIPT,
198+
}
199+
service_scripts = cluster.GetServiceWrapperScriptsToUpload()
200+
for script in service_scripts:
201+
scripts_to_upload[data.ResourcePath(script)] = script
202+
for local_path, bucket_dest in scripts_to_upload.items():
203+
storage_service.CopyToBucket(local_path, cluster.bucket, bucket_dest)
199204

200205
benchmark_spec.table_subdirs = []
201206
benchmark_spec.data_dir = None
@@ -214,67 +219,69 @@ def Prepare(benchmark_spec):
214219
benchmark_spec.data_dir = FLAGS.dpb_sparksql_data
215220

216221

217-
def LoadAndStageQueries(
218-
storage_service: object_storage_service.ObjectStorageService, base_dir: str
219-
) -> str:
220-
"""Loads queries stages them in object storage if needed.
222+
def _FetchQueryContents(
223+
storage_service: object_storage_service.ObjectStorageService,
224+
) -> dict[str, str]:
225+
"""Fetches query contents into a dict from a source depending on flags passed.
221226
222-
Queries are selected using --dpb_sparksql_query and --dpb_sparksql_order.
227+
Queries are selected using --dpb_sparksql_query, --dpb_sparksql_order and
228+
--dpb_sparksql_queries_url.
223229
224230
Args:
225231
storage_service: object_storage_service to stage queries into.
226-
base_dir: object storage directory to stage queries into.
227232
228233
Returns:
229-
The object storage path where the queries are staged into.
234+
A dict where the key corresponds to the query ID and value to the actual
235+
query SQL.
230236
231237
Raises:
232238
PrepareException if a requested query is not found.
233239
"""
234-
235240
if _QUERIES_URL.value:
236-
_GetQueryFilesFromUrl(storage_service, _QUERIES_URL.value)
237-
# casting it, so it doesn't complain about being str | None
238-
return typing.cast(str, _QUERIES_URL.value)
239-
_StageQueriesFromRepo(storage_service, base_dir)
240-
return base_dir
241+
return _FetchQueryFilesFromUrl(storage_service, _QUERIES_URL.value)
242+
return _FetchQueriesFromRepo()
241243

242244

def _FetchQueryFilesFromUrl(
    storage_service: object_storage_service.ObjectStorageService,
    queries_url: str,
) -> dict[str, str]:
  """Downloads the relevant query files from queries_url and reads them.

  Args:
    storage_service: object_storage_service to fetch query files with.
    queries_url: Object Storage directory URL where the benchmark queries are
      contained.

  Returns:
    A dict where the key corresponds to the query ID and value to the actual
    query SQL.

  Raises:
    PrepareException if a requested query is not found at queries_url.
  """
  temp_run_dir = temp_dir.GetRunDirPath()
  spark_sql_queries = os.path.join(temp_run_dir, 'spark_sql_queries')
  # Make sure the local download directory exists before copying into it.
  os.makedirs(spark_sql_queries, exist_ok=True)
  query_paths = {q: os.path.join(queries_url, q) for q in GetQueryIdsToStage()}
  queries = {}
  queries_missing = set()
  for q, remote_path in query_paths.items():
    local_path = os.path.join(spark_sql_queries, q)
    try:
      storage_service.Copy(remote_path, local_path)
      with open(local_path) as f:
        queries[q] = f.read()
    except errors.VmUtil.IssueCommandError:  # Handling query not found.
      queries_missing.add(q)
  if queries_missing:
    raise errors.Benchmarks.PrepareException(
        'Could not find queries {}'.format(', '.join(sorted(queries_missing)))
    )
  return queries
265278

266279

267-
def _StageQueriesFromRepo(
268-
storage_service: object_storage_service.ObjectStorageService, base_dir: str
269-
) -> None:
270-
"""Copies queries from default Github repo to object storage.
271-
272-
Args:
273-
storage_service: object_storage_service to stage queries into.
274-
base_dir: object storage directory to stage queries into.
275-
"""
280+
def _FetchQueriesFromRepo() -> dict[str, str]:
281+
"""Fetches queries from default Github repo to object storage."""
276282
temp_run_dir = temp_dir.GetRunDirPath()
277283
spark_sql_perf_dir = os.path.join(temp_run_dir, 'spark_sql_perf_dir')
284+
queries = {}
278285

279286
# Clone repo
280287
vm_util.IssueCommand(['git', 'clone', SPARK_SQL_PERF_GIT, spark_sql_perf_dir])
@@ -286,23 +293,59 @@ def _StageQueriesFromRepo(
286293
)
287294

288295
# Search repo for queries
289-
query_file = {} # map query -> staged file
290296
queries_to_stage = GetQueryIdsToStage()
291297
for dir_name, _, files in os.walk(query_dir):
292298
for filename in files:
293299
query_id = GetQueryId(filename)
294300
if query_id:
295-
# only upload specified queries
301+
# only load specified queries
296302
if query_id in queries_to_stage:
297-
src_file = os.path.join(dir_name, filename)
298-
staged_filename = query_id
299-
staged_file = '{}/{}'.format(base_dir, staged_filename)
300-
storage_service.Copy(src_file, staged_file)
301-
query_file[query_id] = staged_file
303+
query_path = os.path.join(dir_name, filename)
304+
with open(query_path) as f:
305+
queries[query_id] = f.read()
302306

303307
# Validate all requested queries are present.
304-
missing_queries = set(queries_to_stage) - set(query_file.keys())
308+
missing_queries = set(queries_to_stage) - set(queries.keys())
305309
if missing_queries:
306310
raise errors.Benchmarks.PrepareException(
307311
'Could not find queries {}'.format(missing_queries)
308312
)
313+
314+
return queries
315+
316+
def _RenderRunnerScriptWithQueries(queries: dict[str, str]) -> str:
  """Renders the Spark SQL runner script with all queries to be run inlined.

  The queries are baked in as a QUERIES dict literal replacing the region
  delimited by "# spark_sql_queries:start" and "# spark_sql_queries:end" in the
  original script source.

  Args:
    queries: A dict where each key corresponds to the query ID and each value to
      the actual query SQL.

  Returns:
    The local path of the rendered runner script file.
  """
  lines = ['QUERIES = {']
  for query_id, sql_str in queries.items():
    # repr() produces valid python string literals for the IDs and SQL text.
    lines.append(f'  {query_id!r}: {sql_str!r},')
  lines.append('}')
  queries_dict_source = '\n'.join(lines) + '\n'
  temp_run_dir = temp_dir.GetRunDirPath()
  queries_py_filepath = os.path.join(
      temp_run_dir, os.path.basename(SPARK_SQL_RUNNER_SCRIPT)
  )
  with open(data.ResourcePath(SPARK_SQL_RUNNER_SCRIPT)) as f:
    runner_source = f.read()
  # Passing a callable as repl makes re.sub insert the rendered dict verbatim,
  # so the replacement text needs no backslash escaping.
  contents = re.sub(
      QUERIES_SUB_PATTERN,
      lambda _: queries_dict_source,
      runner_source,
      count=1,
      flags=re.MULTILINE | re.DOTALL,
  )
  with open(queries_py_filepath, 'w') as f:
    f.write(contents)
  return queries_py_filepath

perfkitbenchmarker/linux_benchmarks/dpb_sparksql_benchmark.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,14 @@ def CheckPrerequisites(benchmark_config):
154154
' --dpb_sparksql_bigquery_tables, or dpb_sparksql_database. You will'
155155
' probably not have data to query!'
156156
)
157-
if sum([
158-
bool(FLAGS.dpb_sparksql_data),
159-
bool(_BIGQUERY_TABLES.value),
160-
bool(FLAGS.dpb_sparksql_database),
161-
]) == 1:
157+
if (
158+
sum([
159+
bool(FLAGS.dpb_sparksql_data),
160+
bool(_BIGQUERY_TABLES.value),
161+
bool(FLAGS.dpb_sparksql_database),
162+
])
163+
== 1
164+
):
162165
logging.warning(
163166
'You should only pass one of them: --dpb_sparksql_data,'
164167
' --dpb_sparksql_bigquery_tables, or --dpb_sparksql_database.'
@@ -306,16 +309,16 @@ def _RunQueries(benchmark_spec) -> tuple[str, dpb_service.JobResult]:
306309
"""Runs queries. Returns storage path with metrics and JobResult object."""
307310
cluster = benchmark_spec.dpb_service
308311
report_dir = '/'.join([cluster.base_dir, f'report-{int(time.time()*1000)}'])
309-
args = ['--sql-scripts-dir', benchmark_spec.query_dir]
312+
args = []
310313
if FLAGS.dpb_sparksql_simultaneous:
311314
# Assertion true bc of --dpb_sparksql_simultaneous and
312315
# --dpb_sparksql_streams being mutually exclusive.
313316
assert len(benchmark_spec.query_streams) == 1
314317
for query in benchmark_spec.query_streams[0]:
315-
args += ['--sql-scripts', query]
318+
args += ['--sql-queries', query]
316319
else:
317320
for stream in benchmark_spec.query_streams:
318-
args += ['--sql-scripts', ','.join(stream)]
321+
args += ['--sql-queries', ','.join(stream)]
319322
args += ['--report-dir', report_dir]
320323
if FLAGS.dpb_sparksql_database:
321324
args += ['--database', FLAGS.dpb_sparksql_database]
@@ -388,7 +391,7 @@ def _GetQuerySamples(
388391
for line in file:
389392
result = json.loads(line)
390393
logging.info('Timing: %s', result)
391-
query_id = dpb_sparksql_benchmark_helper.GetQueryId(result['script'])
394+
query_id = result['query_id']
392395
assert query_id
393396
metadata_copy = base_metadata.copy()
394397
metadata_copy['query'] = query_id

0 commit comments

Comments (0)