1616import json
1717import os
1818import re
19- import typing
2019from typing import Any
2120
2221from absl import flags
FLAGS = flags.FLAGS

# Directory (inside the benchmark data package) holding the pyspark scripts.
SCRIPT_DIR = 'spark_sql_test_scripts'
# Basename used for the runner script once queries are rendered into it.
QUERIES_PY_BASENAME = 'spark_sql_queries.py'
SPARK_SQL_DISTCP_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_sql_distcp.py')
# Creates spark table using pyspark by loading the parquet data.
# Args:
# argv[1]: string, The table name in the dataset that this script will create.
# argv[2]: string, The data path of the table.
SPARK_TABLE_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_table.py')
SPARK_SQL_RUNNER_SCRIPT = os.path.join(SCRIPT_DIR, 'spark_sql_runner.py')
QUERIES_SCRIPT = os.path.join(SCRIPT_DIR, QUERIES_PY_BASENAME)
SPARK_SQL_PERF_GIT = 'https://github.com/databricks/spark-sql-perf.git'
SPARK_SQL_PERF_GIT_COMMIT = '6b2bf9f9ad6f6c2f620062fda78cded203f619c8'
# Regex (used with re.MULTILINE | re.DOTALL) matching the placeholder region
# in the runner script, delimited by the lines "# spark_sql_queries:start" and
# "# spark_sql_queries:end", which gets replaced by the rendered QUERIES dict.
QUERIES_SUB_PATTERN = (
    r'^#[\s]*spark_sql_queries:start[\s]*$.*^#[\s]*spark_sql_queries:end[\s]*$'
)
142146
143147
144148def GetStreams () -> list [list [str ]]:
@@ -183,19 +187,20 @@ def Prepare(benchmark_spec):
183187 """Copies scripts and all the queries to cloud."""
184188 cluster = benchmark_spec .dpb_service
185189 storage_service = cluster .storage_service
186- benchmark_spec .query_dir = LoadAndStageQueries (
187- storage_service , cluster .base_dir
188- )
190+ queries = _FetchQueryContents (storage_service )
191+ rendered_runner_filepath = _RenderRunnerScriptWithQueries (queries )
189192 benchmark_spec .query_streams = GetStreams ()
190193
191- scripts_to_upload = [
192- SPARK_SQL_DISTCP_SCRIPT ,
193- SPARK_TABLE_SCRIPT ,
194- SPARK_SQL_RUNNER_SCRIPT ,
195- ] + cluster .GetServiceWrapperScriptsToUpload ()
196- for script in scripts_to_upload :
197- src_url = data .ResourcePath (script )
198- storage_service .CopyToBucket (src_url , cluster .bucket , script )
194+ scripts_to_upload = {
195+ data .ResourcePath (SPARK_SQL_DISTCP_SCRIPT ): SPARK_SQL_DISTCP_SCRIPT ,
196+ data .ResourcePath (SPARK_TABLE_SCRIPT ): SPARK_TABLE_SCRIPT ,
197+ rendered_runner_filepath : SPARK_SQL_RUNNER_SCRIPT ,
198+ }
199+ service_scripts = cluster .GetServiceWrapperScriptsToUpload ()
200+ for script in service_scripts :
201+ scripts_to_upload [data .ResourcePath (script )] = script
202+ for local_path , bucket_dest in scripts_to_upload .items ():
203+ storage_service .CopyToBucket (local_path , cluster .bucket , bucket_dest )
199204
200205 benchmark_spec .table_subdirs = []
201206 benchmark_spec .data_dir = None
@@ -214,67 +219,69 @@ def Prepare(benchmark_spec):
214219 benchmark_spec .data_dir = FLAGS .dpb_sparksql_data
215220
216221
def _FetchQueryContents(
    storage_service: object_storage_service.ObjectStorageService,
) -> dict[str, str]:
  """Fetches query contents into a dict from a source depending on flags passed.

  Queries are selected using --dpb_sparksql_query, --dpb_sparksql_order and
  --dpb_sparksql_queries_url.

  Args:
    storage_service: object_storage_service to stage queries into.

  Returns:
    A dict where the key corresponds to the query ID and value to the actual
    query SQL.

  Raises:
    PrepareException if a requested query is not found.
  """
  queries_url = _QUERIES_URL.value
  if not queries_url:
    # No explicit URL given: pull the queries from the spark-sql-perf repo.
    return _FetchQueriesFromRepo()
  return _FetchQueryFilesFromUrl(storage_service, queries_url)
241243
242244
def _FetchQueryFilesFromUrl(
    storage_service: object_storage_service.ObjectStorageService,
    queries_url: str,
) -> dict[str, str]:
  """Downloads the requested query files from queries_url and reads them.

  Args:
    storage_service: object_storage_service to fetch query files.
    queries_url: Object Storage directory URL where the benchmark queries are
      contained.

  Returns:
    A dict where the key corresponds to the query ID and value to the actual
    query SQL.

  Raises:
    errors.Benchmarks.PrepareException: if any requested query is not found.
  """
  temp_run_dir = temp_dir.GetRunDirPath()
  spark_sql_queries = os.path.join(temp_run_dir, 'spark_sql_queries')
  # Make sure the local download directory exists before copying into it.
  os.makedirs(spark_sql_queries, exist_ok=True)
  query_paths = {q: os.path.join(queries_url, q) for q in GetQueryIdsToStage()}
  queries = {}
  queries_missing = set()
  for q, remote_path in query_paths.items():
    local_path = os.path.join(spark_sql_queries, q)
    try:
      storage_service.Copy(remote_path, local_path)
      with open(local_path) as f:
        queries[q] = f.read()
    except errors.VmUtil.IssueCommandError:  # Handling query not found
      queries_missing.add(q)
  if queries_missing:
    raise errors.Benchmarks.PrepareException(
        'Could not find queries {}'.format(', '.join(sorted(queries_missing)))
    )
  return queries
265278
266279
267- def _StageQueriesFromRepo (
268- storage_service : object_storage_service .ObjectStorageService , base_dir : str
269- ) -> None :
270- """Copies queries from default Github repo to object storage.
271-
272- Args:
273- storage_service: object_storage_service to stage queries into.
274- base_dir: object storage directory to stage queries into.
275- """
280+ def _FetchQueriesFromRepo () -> dict [str , str ]:
281+ """Fetches queries from default Github repo to object storage."""
276282 temp_run_dir = temp_dir .GetRunDirPath ()
277283 spark_sql_perf_dir = os .path .join (temp_run_dir , 'spark_sql_perf_dir' )
284+ queries = {}
278285
279286 # Clone repo
280287 vm_util .IssueCommand (['git' , 'clone' , SPARK_SQL_PERF_GIT , spark_sql_perf_dir ])
@@ -286,23 +293,59 @@ def _StageQueriesFromRepo(
286293 )
287294
288295 # Search repo for queries
289- query_file = {} # map query -> staged file
290296 queries_to_stage = GetQueryIdsToStage ()
291297 for dir_name , _ , files in os .walk (query_dir ):
292298 for filename in files :
293299 query_id = GetQueryId (filename )
294300 if query_id :
295- # only upload specified queries
301+ # only load specified queries
296302 if query_id in queries_to_stage :
297- src_file = os .path .join (dir_name , filename )
298- staged_filename = query_id
299- staged_file = '{}/{}' .format (base_dir , staged_filename )
300- storage_service .Copy (src_file , staged_file )
301- query_file [query_id ] = staged_file
303+ query_path = os .path .join (dir_name , filename )
304+ with open (query_path ) as f :
305+ queries [query_id ] = f .read ()
302306
303307 # Validate all requested queries are present.
304- missing_queries = set (queries_to_stage ) - set (query_file .keys ())
308+ missing_queries = set (queries_to_stage ) - set (queries .keys ())
305309 if missing_queries :
306310 raise errors .Benchmarks .PrepareException (
307311 'Could not find queries {}' .format (missing_queries )
308312 )
313+
314+ return queries
315+
316+
def _RenderRunnerScriptWithQueries(queries: dict[str, str]) -> str:
  """Renders a Spark SQL runner file with a dict having all queries to be run.

  The dict will be in the QUERIES variable inside the region delimited by
  "# spark_sql_queries:start" and "# spark_sql_queries:end" in the original
  script source.

  Args:
    queries: A dict where each key corresponds to the query ID and each value to
      the actual query SQL.

  Returns:
    The local path of the rendered python file.

  Raises:
    errors.Benchmarks.PrepareException: if the placeholder region is missing
      from the runner script source.
  """
  lines = ['QUERIES = {']
  for query_id, sql_str in queries.items():
    lines.append(f'  {query_id!r}: {sql_str!r},')
  lines.append('}')
  queries_dict_source = '\n'.join(lines) + '\n'
  with open(data.ResourcePath(SPARK_SQL_RUNNER_SCRIPT)) as f:
    runner_source = f.read()
  # A callable replacement inserts the rendered dict verbatim, so backslashes
  # in the query SQL are never parsed as regex escape sequences (no manual
  # backslash doubling needed).
  contents, substitutions = re.subn(
      QUERIES_SUB_PATTERN,
      lambda _: queries_dict_source,
      runner_source,
      count=1,
      flags=re.MULTILINE | re.DOTALL,
  )
  if not substitutions:
    # Fail during Prepare rather than uploading a runner with no queries.
    raise errors.Benchmarks.PrepareException(
        'Could not find the spark_sql_queries placeholder region in '
        f'{SPARK_SQL_RUNNER_SCRIPT}'
    )
  temp_run_dir = temp_dir.GetRunDirPath()
  queries_py_filepath = os.path.join(
      temp_run_dir, os.path.basename(SPARK_SQL_RUNNER_SCRIPT)
  )
  with open(queries_py_filepath, 'w') as f:
    f.write(contents)
  return queries_py_filepath
0 commit comments