Skip to content

Commit 0eb1cbe

Browse files
dorellang authored and copybara-github committed
Test table metadata file read by pyspark script to create temporary view tables.
PiperOrigin-RevId: 715021916
1 parent 7a2987d commit 0eb1cbe

File tree

2 files changed

+278
-5
lines changed

2 files changed

+278
-5
lines changed

perfkitbenchmarker/scripts/spark_sql_test_scripts/spark_sql_runner.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def parse_args(args=None):
9292
return parser.parse_args(args)
9393

9494

95-
def load_file(spark, object_path):
95+
def _load_file(spark, object_path):
9696
"""Load an HCFS file into a string."""
9797
return '\n'.join(spark.sparkContext.textFile(object_path).collect())
9898

@@ -110,7 +110,7 @@ def main(args):
110110
spark.catalog.setCurrentDatabase(args.database)
111111
table_metadata = []
112112
if args.table_metadata:
113-
table_metadata = json.loads(load_file(spark, args.table_metadata)).items()
113+
table_metadata = get_table_metadata(spark, args).items()
114114
for name, (fmt, options) in table_metadata:
115115
logging.info('Loading %s', name)
116116
spark.read.format(fmt).options(**options).load().createTempView(name)
@@ -170,6 +170,11 @@ def get_script_streams(args):
170170
]
171171

172172

173+
def get_table_metadata(spark, args):
  """Gets table metadata to create temporary views.

  Args:
    spark: Active SparkSession used to read the metadata file.
    args: Parsed runner arguments; args.table_metadata is the HCFS path of
      the JSON metadata file staged by the benchmark.

  Returns:
    The deserialized table metadata object.
  """
  metadata_json = _load_file(spark, args.table_metadata)
  return json.loads(metadata_json)
176+
177+
173178
def run_sql_script(
174179
spark_session, script_stream, stream_id, raise_query_execution_errors
175180
):
@@ -178,7 +183,7 @@ def run_sql_script(
178183
results = []
179184
for script in script_stream:
180185
# Read script from object storage using rdd API
181-
query = load_file(spark_session, script)
186+
query = _load_file(spark_session, script)
182187

183188
try:
184189
logging.info('Running %s', script)

tests/linux_benchmarks/dpb_sparksql_benchmark_test.py

Lines changed: 270 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,35 @@
1-
"""Tests for dpb_sparksql_benchmark."""
2-
1+
import json
2+
import sys
3+
from typing import Any
34
import unittest
45
from unittest import mock
56

67
from absl.testing import flagsaver
8+
from absl.testing import parameterized
79
import freezegun
810
from perfkitbenchmarker import dpb_sparksql_benchmark_helper
911
from perfkitbenchmarker.linux_benchmarks import dpb_sparksql_benchmark
1012
from tests import pkb_common_test_case
1113

14+
PY4J_MOCK = mock.Mock()
15+
PYSPARK_MOCK = mock.Mock()
16+
sys.modules['py4j'] = PY4J_MOCK
17+
sys.modules['pyspark'] = PYSPARK_MOCK
18+
19+
from perfkitbenchmarker.scripts.spark_sql_test_scripts import spark_sql_runner
20+
21+
22+
_TPCH_TABLES = [
23+
'customer',
24+
'lineitem',
25+
'nation',
26+
'orders',
27+
'part',
28+
'partsupp',
29+
'region',
30+
'supplier',
31+
]
32+
1233

1334
class DpbSparksqlBenchmarkTest(pkb_common_test_case.PkbCommonTestCase):
1435

@@ -99,6 +120,253 @@ def testRunQueriesSimultaneous(self):
99120
],
100121
)
101122

123+
def SetupTableMetadataMocks(self):
124+
staged_metadata: str | None = None
125+
126+
def _FakeStageMetadata(
127+
table_metadata: dict[Any, Any],
128+
storage_service: Any,
129+
table_metadata_file: Any,
130+
):
131+
nonlocal staged_metadata
132+
del storage_service, table_metadata_file
133+
staged_metadata = json.dumps(table_metadata)
134+
135+
stage_metadata_mock = self.enter_context(
136+
mock.patch.object(
137+
dpb_sparksql_benchmark_helper,
138+
'StageMetadata',
139+
side_effect=_FakeStageMetadata,
140+
)
141+
)
142+
spark_sql_runner_mock = self.enter_context(
143+
mock.patch.object(
144+
spark_sql_runner,
145+
'_load_file',
146+
side_effect=lambda *args, **kwargs: staged_metadata,
147+
)
148+
)
149+
return stage_metadata_mock, spark_sql_runner_mock
150+
151+
@parameterized.named_parameters(
152+
dict(testcase_name='Default', extra_flags={}, want_format='parquet'),
153+
dict(
154+
testcase_name='OrcFormat',
155+
extra_flags={'dpb_sparksql_data_format': 'orc'},
156+
want_format='orc',
157+
),
158+
)
159+
@flagsaver.flagsaver(dpb_sparksql_order=['1', '2', '3'])
160+
def testRunnerScriptGetTableMetadata(self, extra_flags, want_format):
161+
# Arrange
162+
stage_metadata_mock, spark_sql_runner_mock = self.SetupTableMetadataMocks()
163+
self.benchmark_spec_mock.query_dir = 'gs://test'
164+
self.benchmark_spec_mock.data_dir = 'gs://datasetbucket'
165+
self.benchmark_spec_mock.query_streams = (
166+
dpb_sparksql_benchmark_helper.GetStreams()
167+
)
168+
self.benchmark_spec_mock.table_subdirs = list(_TPCH_TABLES)
169+
if extra_flags:
170+
self.enter_context(flagsaver.flagsaver(**extra_flags))
171+
172+
# Act
173+
dpb_sparksql_benchmark._RunQueries(self.benchmark_spec_mock)
174+
table_metadata = spark_sql_runner.get_table_metadata(
175+
mock.MagicMock(), mock.MagicMock()
176+
)
177+
178+
# Assert
179+
stage_metadata_mock.assert_called_once()
180+
spark_sql_runner_mock.assert_called_once()
181+
self.assertEqual(
182+
table_metadata,
183+
{
184+
'customer': [want_format, {'path': 'gs://datasetbucket/customer'}],
185+
'lineitem': [want_format, {'path': 'gs://datasetbucket/lineitem'}],
186+
'nation': [want_format, {'path': 'gs://datasetbucket/nation'}],
187+
'orders': [want_format, {'path': 'gs://datasetbucket/orders'}],
188+
'part': [want_format, {'path': 'gs://datasetbucket/part'}],
189+
'partsupp': [want_format, {'path': 'gs://datasetbucket/partsupp'}],
190+
'region': [want_format, {'path': 'gs://datasetbucket/region'}],
191+
'supplier': [want_format, {'path': 'gs://datasetbucket/supplier'}],
192+
},
193+
)
194+
195+
@parameterized.named_parameters(
196+
dict(testcase_name='Default', extra_flags={}, want_delim=','),
197+
dict(
198+
testcase_name='PipeDelim',
199+
extra_flags={'dpb_sparksql_csv_delimiter': '|'},
200+
want_delim='|',
201+
),
202+
)
203+
@flagsaver.flagsaver(
204+
dpb_sparksql_order=['1', '2', '3'], dpb_sparksql_data_format='csv'
205+
)
206+
def testRunnerScriptGetTableMetadataCsv(self, extra_flags, want_delim):
207+
# Arrange
208+
stage_metadata_mock, spark_sql_runner_mock = self.SetupTableMetadataMocks()
209+
self.benchmark_spec_mock.query_dir = 'gs://test'
210+
self.benchmark_spec_mock.data_dir = 'gs://datasetbucket'
211+
self.benchmark_spec_mock.query_streams = (
212+
dpb_sparksql_benchmark_helper.GetStreams()
213+
)
214+
self.benchmark_spec_mock.table_subdirs = list(_TPCH_TABLES)
215+
if extra_flags:
216+
self.enter_context(flagsaver.flagsaver(**extra_flags))
217+
218+
# Act
219+
dpb_sparksql_benchmark._RunQueries(self.benchmark_spec_mock)
220+
table_metadata = spark_sql_runner.get_table_metadata(
221+
mock.MagicMock(), mock.MagicMock()
222+
)
223+
224+
# Assert
225+
stage_metadata_mock.assert_called_once()
226+
spark_sql_runner_mock.assert_called_once()
227+
228+
self.assertEqual(
229+
table_metadata,
230+
{
231+
'customer': [
232+
'csv',
233+
{
234+
'path': 'gs://datasetbucket/customer',
235+
'header': 'true',
236+
'delimiter': want_delim,
237+
},
238+
],
239+
'lineitem': [
240+
'csv',
241+
{
242+
'path': 'gs://datasetbucket/lineitem',
243+
'header': 'true',
244+
'delimiter': want_delim,
245+
},
246+
],
247+
'nation': [
248+
'csv',
249+
{
250+
'path': 'gs://datasetbucket/nation',
251+
'header': 'true',
252+
'delimiter': want_delim,
253+
},
254+
],
255+
'orders': [
256+
'csv',
257+
{
258+
'path': 'gs://datasetbucket/orders',
259+
'header': 'true',
260+
'delimiter': want_delim,
261+
},
262+
],
263+
'part': [
264+
'csv',
265+
{
266+
'path': 'gs://datasetbucket/part',
267+
'header': 'true',
268+
'delimiter': want_delim,
269+
},
270+
],
271+
'partsupp': [
272+
'csv',
273+
{
274+
'path': 'gs://datasetbucket/partsupp',
275+
'header': 'true',
276+
'delimiter': want_delim,
277+
},
278+
],
279+
'region': [
280+
'csv',
281+
{
282+
'path': 'gs://datasetbucket/region',
283+
'header': 'true',
284+
'delimiter': want_delim,
285+
},
286+
],
287+
'supplier': [
288+
'csv',
289+
{
290+
'path': 'gs://datasetbucket/supplier',
291+
'header': 'true',
292+
'delimiter': want_delim,
293+
},
294+
],
295+
},
296+
)
297+
298+
@flagsaver.flagsaver(
299+
dpb_sparksql_order=['1', '2', '3'],
300+
bigquery_tables=[
301+
'tpcds_1t.customer',
302+
'tpcds_1t.lineitem',
303+
'tpcds_1t.nation',
304+
'tpcds_1t.orders',
305+
'tpcds_1t.part',
306+
'tpcds_1t.partsupp',
307+
'tpcds_1t.region',
308+
'tpcds_1t.supplier',
309+
],
310+
bigquery_record_format='ARROW',
311+
dpb_sparksql_data_format='com.google.cloud.spark.bigquery',
312+
)
313+
def testRunnerScriptGetTableMetadataBigQuery(self):
314+
# Arrange
315+
stage_metadata_mock, spark_sql_runner_mock = self.SetupTableMetadataMocks()
316+
self.benchmark_spec_mock.query_dir = 'gs://test'
317+
self.benchmark_spec_mock.data_dir = 'gs://datasetbucket'
318+
self.benchmark_spec_mock.query_streams = (
319+
dpb_sparksql_benchmark_helper.GetStreams()
320+
)
321+
self.benchmark_spec_mock.table_subdirs = list(_TPCH_TABLES)
322+
323+
# Act
324+
dpb_sparksql_benchmark._RunQueries(self.benchmark_spec_mock)
325+
table_metadata = spark_sql_runner.get_table_metadata(
326+
mock.MagicMock(), mock.MagicMock()
327+
)
328+
329+
# Assert
330+
stage_metadata_mock.assert_called_once()
331+
spark_sql_runner_mock.assert_called_once()
332+
self.assertEqual(
333+
table_metadata,
334+
{
335+
'customer': [
336+
'com.google.cloud.spark.bigquery',
337+
{'table': 'tpcds_1t.customer', 'readDataFormat': 'ARROW'},
338+
],
339+
'lineitem': [
340+
'com.google.cloud.spark.bigquery',
341+
{'table': 'tpcds_1t.lineitem', 'readDataFormat': 'ARROW'},
342+
],
343+
'nation': [
344+
'com.google.cloud.spark.bigquery',
345+
{'table': 'tpcds_1t.nation', 'readDataFormat': 'ARROW'},
346+
],
347+
'orders': [
348+
'com.google.cloud.spark.bigquery',
349+
{'table': 'tpcds_1t.orders', 'readDataFormat': 'ARROW'},
350+
],
351+
'part': [
352+
'com.google.cloud.spark.bigquery',
353+
{'table': 'tpcds_1t.part', 'readDataFormat': 'ARROW'},
354+
],
355+
'partsupp': [
356+
'com.google.cloud.spark.bigquery',
357+
{'table': 'tpcds_1t.partsupp', 'readDataFormat': 'ARROW'},
358+
],
359+
'region': [
360+
'com.google.cloud.spark.bigquery',
361+
{'table': 'tpcds_1t.region', 'readDataFormat': 'ARROW'},
362+
],
363+
'supplier': [
364+
'com.google.cloud.spark.bigquery',
365+
{'table': 'tpcds_1t.supplier', 'readDataFormat': 'ARROW'},
366+
],
367+
},
368+
)
369+
102370

103371
if __name__ == '__main__':
104372
unittest.main()

0 commit comments

Comments (0)