 from pyspark.sql import SparkSession
 import time
 
-def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str, name: str):
+def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str, name: str, query_num: int = None):
 
     # Initialize a SparkSession
     spark = SparkSession.builder \
@@ -59,9 +59,17 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
 
     for iteration in range(0, iterations):
         print(f"Starting iteration {iteration} of {iterations}")
-        iter_start_time = time.time()
 
-        for query in range(1, num_queries + 1):
+        # Determine which queries to run
+        if query_num is not None:
+            # Validate query number
+            if query_num < 1 or query_num > num_queries:
+                raise ValueError(f"Query number {query_num} is out of range. Valid range is 1-{num_queries} for {benchmark}")
+            queries_to_run = [query_num]
+        else:
+            queries_to_run = range(1, num_queries + 1)
+
+        for query in queries_to_run:
             spark.sparkContext.setJobDescription(f"{benchmark} q{query}")
 
             # read text file
@@ -105,8 +113,6 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
     # Stop the SparkSession
     spark.stop()
 
-#print(str)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="DataFusion benchmark derived from TPC-H / TPC-DS")
     parser.add_argument("--benchmark", required=True, help="Benchmark to run (tpch or tpcds)")
@@ -115,6 +121,7 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
     parser.add_argument("--iterations", required=False, default="1", help="How many iterations to run")
     parser.add_argument("--output", required=True, help="Path to write output")
     parser.add_argument("--name", required=True, help="Prefix for result file e.g. spark/comet/gluten")
+    parser.add_argument("--query", required=False, type=int, help="Specific query number to run (1-based). If not specified, all queries will be run.")
     args = parser.parse_args()
 
-    main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name)
+    main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name, args.query)
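
For reference, a minimal sketch of how the new single-query mode can be exercised by calling main() directly. The module name tpcbench and all paths below are assumptions for illustration; only the parameter names come from the change above. Passing query_num is equivalent to supplying --query on the command line, and omitting it keeps the previous behavior of running every query.

    # Minimal sketch: run only query 5 of the TPC-H suite for one iteration.
    # Assumptions: the script is importable as tpcbench; all paths are placeholders.
    from tpcbench import main

    main(
        benchmark="tpch",
        data_path="/data/tpch-parquet",   # assumed location of the benchmark data
        query_path="/queries/tpch",       # assumed directory containing the SQL files
        iterations=1,
        output="/tmp/results",            # assumed output directory
        name="spark",
        query_num=5,                      # new parameter introduced by this change
    )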