
Commit a89de5f

Add Spark to remote benchmarks (#280)
* Add ballista to benchmarks
* Add Spark to benchmarks
1 parent 086b89f commit a89de5f

6 files changed: +518 -2 lines changed


benchmarks/cdk/bin/spark-bench.ts

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import path from "path";
import {Command} from "commander";
import {z} from 'zod';
import {BenchmarkRunner, ROOT, runBenchmark, TableSpec} from "./@bench-common";

// Remember to port-forward the Spark HTTP server with
// aws ssm start-session --target {host-id} --document-name AWS-StartPortForwardingSession --parameters "portNumber=9003,localPortNumber=9003"

async function main() {
  const program = new Command();

  program
    .option('--dataset <string>', 'Dataset to run queries on')
    .option('-i, --iterations <number>', 'Number of iterations', '3')
    .option('--query <number>', 'A specific query to run', undefined)
    .parse(process.argv);

  const options = program.opts();

  const dataset: string = options.dataset
  const iterations = parseInt(options.iterations);
  const queries = options.query ? [parseInt(options.query)] : [];

  const runner = new SparkRunner({});

  const datasetPath = path.join(ROOT, "benchmarks", "data", dataset);
  const outputPath = path.join(datasetPath, "remote-results.json")

  await runBenchmark(runner, {
    dataset,
    iterations,
    queries,
    outputPath,
  });
}

const QueryResponse = z.object({
  count: z.number()
})
type QueryResponse = z.infer<typeof QueryResponse>

class SparkRunner implements BenchmarkRunner {
  private url = 'http://localhost:9003';

  constructor(private readonly options: {}) {
  }

  async executeQuery(sql: string): Promise<{ rowCount: number }> {
    // Fix TPCH query 4: Add DATE prefix to date literals
    sql = sql.replace(/(?<!date\s)('[\d]{4}-[\d]{2}-[\d]{2}')/gi, 'DATE $1');

    // Fix ClickBench queries: Spark uses from_unixtime
    sql = sql.replace(/to_timestamp_seconds\(/gi, 'from_unixtime(');

    let response
    if (sql.includes("create view")) {
      // Query 15
      let [createView, query, dropView] = sql.split(";")
      await this.query(createView);
      response = await this.query(query)
      await this.query(dropView);
    } else {
      response = await this.query(sql)
    }

    return { rowCount: response.count };
  }

  private async query(sql: string): Promise<QueryResponse> {
    const response = await fetch(`${this.url}/query`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        query: sql.trim().replace(/;+$/, '')
      })
    });

    if (!response.ok) {
      const msg = await response.text();
      throw new Error(`Query failed: ${response.status} ${msg}`);
    }

    return QueryResponse.parse(await response.json());
  }

  async createTables(tables: TableSpec[]): Promise<void> {
    for (const table of tables) {
      // Spark requires s3a:// protocol, not s3://
      const s3aPath = table.s3Path.replace('s3://', 's3a://');

      // Create temporary view from Parquet files
      const createViewStmt = `
        CREATE OR REPLACE TEMPORARY VIEW ${table.name}
        USING parquet
        OPTIONS (path '${s3aPath}')
      `;
      await this.query(createViewStmt);
    }
  }

}

main()
  .catch(err => {
    console.error(err)
    process.exit(1)
  })
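
For context, the two SQL rewrites applied in executeQuery above can be exercised on their own. A minimal TypeScript sketch follows; the sample predicates are illustrative stand-ins, not literal benchmark queries:

// rewrite-demo.ts -- standalone illustration of the literal rewrites used by SparkRunner
const fixDateLiterals = (sql: string) =>
  sql.replace(/(?<!date\s)('[\d]{4}-[\d]{2}-[\d]{2}')/gi, 'DATE $1');

const fixTimestampFn = (sql: string) =>
  sql.replace(/to_timestamp_seconds\(/gi, 'from_unixtime(');

// "o_orderdate >= '1993-07-01'"      -> "o_orderdate >= DATE '1993-07-01'"
console.log(fixDateLiterals("o_orderdate >= '1993-07-01'"));

// Already-prefixed literals are left alone thanks to the negative lookbehind.
console.log(fixDateLiterals("o_orderdate >= date '1993-07-01'"));

// "to_timestamp_seconds(EventTime)"  -> "from_unixtime(EventTime)"
console.log(fixTimestampFn("to_timestamp_seconds(EventTime)"));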

benchmarks/cdk/bin/spark_http.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import os
from flask import Flask, request, jsonify
from pyspark.sql import SparkSession

app = Flask(__name__)

# Initialize Spark session
spark = None

def get_spark():
    global spark
    if spark is None:
        master_host = os.environ.get('SPARK_MASTER_HOST', 'localhost')
        spark_jars = os.environ.get('SPARK_JARS', '/opt/spark/jars/hadoop-aws-3.4.1.jar,/opt/spark/jars/bundle-2.29.52.jar,/opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar')
        spark = SparkSession.builder \
            .appName("SparkHTTPServer") \
            .master(f"spark://{master_host}:7077") \
            .config("spark.jars", spark_jars) \
            .config("spark.sql.catalogImplementation", "hive") \
            .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
            .config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.InstanceProfileCredentialsProvider") \
            .enableHiveSupport() \
            .getOrCreate()

        # Set log level to reduce noise
        spark.sparkContext.setLogLevel("WARN")
    return spark

@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({"status": "healthy"}), 200

@app.route('/query', methods=['POST'])
def execute_query():
    """Execute a SQL query on Spark"""
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({"error": "Missing 'query' in request body"}), 400

        query = data['query']

        # Execute the query
        spark_session = get_spark()
        df = spark_session.sql(query)

        # Get row count without collecting all data
        count = df.count()

        return jsonify({"count": count}), 200

    except Exception as e:
        return str(e), 500

if __name__ == '__main__':
    # Run Flask server on the configured port (defaults to 9003)
    port = int(os.environ.get('HTTP_PORT', 9003))
    app.run(host='0.0.0.0', port=port, debug=False)
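
Once the SSM port-forward noted in spark-bench.ts is active, the endpoint can be checked directly. Below is a minimal client sketch in TypeScript, assuming the server is reachable on localhost:9003 and using 'SELECT 1' as a stand-in query; it mirrors the request/response contract SparkRunner relies on:

// query-check.ts -- sketch: POST a query to the Flask bridge and read back the row count
async function checkSparkBridge(): Promise<void> {
  const res = await fetch('http://localhost:9003/query', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ query: 'SELECT 1' }),
  });

  if (!res.ok) {
    // Errors come back as a plain-text body with status 500 (see execute_query above)
    throw new Error(`Query failed: ${res.status} ${await res.text()}`);
  }

  // The server returns only the row count, e.g. {"count": 1}
  const { count } = await res.json() as { count: number };
  console.log(`rows: ${count}`);
}

checkSparkBridge().catch(err => {
  console.error(err);
  process.exit(1);
});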

benchmarks/cdk/lib/cdk-stack.ts

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@ import {Construct} from 'constructs';
 import {DATAFUSION_DISTRIBUTED_ENGINE} from "./datafusion-distributed";
 import {BALLISTA_ENGINE} from "./ballista";
 import {TRINO_ENGINE} from "./trino";
+import {SPARK_ENGINE} from "./spark";
 import path from "path";
 import * as cr from "aws-cdk-lib/custom-resources";

@@ -17,7 +18,8 @@ if (USER_DATA_CAUSES_REPLACEMENT) {
 const ENGINES = [
   DATAFUSION_DISTRIBUTED_ENGINE,
   BALLISTA_ENGINE,
-  TRINO_ENGINE
+  TRINO_ENGINE,
+  SPARK_ENGINE
 ]

 export const ROOT = path.join(__dirname, '../../..')
