109 changes: 109 additions & 0 deletions benchmarks/cdk/bin/spark-bench.ts
@@ -0,0 +1,109 @@
import path from "path";
import {Command} from "commander";
import {z} from 'zod';
import {BenchmarkRunner, ROOT, runBenchmark, TableSpec} from "./@bench-common";

// Remember to port-forward the Spark HTTP server with
// aws ssm start-session --target {host-id} --document-name AWS-StartPortForwardingSession --parameters "portNumber=9003,localPortNumber=9003"
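// A hypothetical invocation once the tunnel is up (entry point and dataset name are illustrative):
//   npx ts-node benchmarks/cdk/bin/spark-bench.ts --dataset tpch --iterations 3 --query 4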

async function main() {
const program = new Command();

program
.option('--dataset <string>', 'Dataset to run queries on')
.option('-i, --iterations <number>', 'Number of iterations', '3')
.option('--query <number>', 'A specific query to run', undefined)
.parse(process.argv);

const options = program.opts();

    const dataset: string = options.dataset;
    if (!dataset) {
        throw new Error('--dataset is required');
    }
const iterations = parseInt(options.iterations);
const queries = options.query ? [parseInt(options.query)] : [];

const runner = new SparkRunner({});

const datasetPath = path.join(ROOT, "benchmarks", "data", dataset);
    const outputPath = path.join(datasetPath, "remote-results.json");

await runBenchmark(runner, {
dataset,
iterations,
queries,
outputPath,
});
}

const QueryResponse = z.object({
    count: z.number()
});
type QueryResponse = z.infer<typeof QueryResponse>;

class SparkRunner implements BenchmarkRunner {
private url = 'http://localhost:9003';

constructor(private readonly options: {}) {
}

async executeQuery(sql: string): Promise<{ rowCount: number }> {
// Fix TPCH query 4: Add DATE prefix to date literals
sql = sql.replace(/(?<!date\s)('[\d]{4}-[\d]{2}-[\d]{2}')/gi, 'DATE $1');
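        // e.g. "... >= '1993-07-01'" becomes "... >= DATE '1993-07-01'"; literals already prefixed with DATE are left alone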

// Fix ClickBench queries: Spark uses from_unixtime
sql = sql.replace(/to_timestamp_seconds\(/gi, 'from_unixtime(');
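        // e.g. to_timestamp_seconds(col) becomes from_unixtime(col)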

        let response: QueryResponse;
        if (sql.includes("create view")) {
            // TPC-H query 15 arrives as three statements: create the view, run the query, drop the view
            const [createView, query, dropView] = sql.split(";");
            await this.query(createView);
            response = await this.query(query);
            await this.query(dropView);
        } else {
            response = await this.query(sql);
        }

return { rowCount: response.count };
}

private async query(sql: string): Promise<QueryResponse> {
const response = await fetch(`${this.url}/query`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
query: sql.trim().replace(/;+$/, '')
})
});

if (!response.ok) {
const msg = await response.text();
throw new Error(`Query failed: ${response.status} ${msg}`);
}

return QueryResponse.parse(await response.json());
}

async createTables(tables: TableSpec[]): Promise<void> {
for (const table of tables) {
// Spark requires s3a:// protocol, not s3://
const s3aPath = table.s3Path.replace('s3://', 's3a://');
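            // e.g. s3://my-bucket/tpch/lineitem/ -> s3a://my-bucket/tpch/lineitem/ (illustrative bucket and key)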

// Create temporary view from Parquet files
const createViewStmt = `
CREATE OR REPLACE TEMPORARY VIEW ${table.name}
USING parquet
OPTIONS (path '${s3aPath}')
`;
await this.query(createViewStmt);
}
}

}

main()
    .catch(err => {
        console.error(err);
        process.exit(1);
    });
60 changes: 60 additions & 0 deletions benchmarks/cdk/bin/spark_http.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import os
from flask import Flask, request, jsonify
from pyspark.sql import SparkSession

app = Flask(__name__)

# Spark session, created lazily on first use
spark = None

def get_spark():
global spark
if spark is None:
master_host = os.environ.get('SPARK_MASTER_HOST', 'localhost')
spark_jars = os.environ.get('SPARK_JARS', '/opt/spark/jars/hadoop-aws-3.4.1.jar,/opt/spark/jars/bundle-2.29.52.jar,/opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar')
spark = SparkSession.builder \
.appName("SparkHTTPServer") \
.master(f"spark://{master_host}:7077") \
.config("spark.jars", spark_jars) \
.config("spark.sql.catalogImplementation", "hive") \
.config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
.config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.InstanceProfileCredentialsProvider") \
.enableHiveSupport() \
.getOrCreate()

# Set log level to reduce noise
spark.sparkContext.setLogLevel("WARN")
return spark

@app.route('/health', methods=['GET'])
def health():
"""Health check endpoint"""
return jsonify({"status": "healthy"}), 200

@app.route('/query', methods=['POST'])
def execute_query():
"""Execute a SQL query on Spark"""
try:
data = request.get_json()
if not data or 'query' not in data:
return jsonify({"error": "Missing 'query' in request body"}), 400

query = data['query']

# Execute the query
spark_session = get_spark()
df = spark_session.sql(query)

# Get row count without collecting all data
count = df.count()

return jsonify({"count": count}), 200

except Exception as e:
return str(e), 500
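# Example request against a local server (assumes port 9003 and a reachable Spark master):
#   curl -X POST http://localhost:9003/query \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "SELECT 1"}'
# Expected response: {"count": 1}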

if __name__ == '__main__':
    # Run the Flask server (defaults to port 9003; override with HTTP_PORT)
port = int(os.environ.get('HTTP_PORT', 9003))
app.run(host='0.0.0.0', port=port, debug=False)
4 changes: 3 additions & 1 deletion benchmarks/cdk/lib/cdk-stack.ts
@@ -6,6 +6,7 @@ import {Construct} from 'constructs';
import {DATAFUSION_DISTRIBUTED_ENGINE} from "./datafusion-distributed";
import {BALLISTA_ENGINE} from "./ballista";
import {TRINO_ENGINE} from "./trino";
import {SPARK_ENGINE} from "./spark";
import path from "path";
import * as cr from "aws-cdk-lib/custom-resources";

@@ -17,7 +18,8 @@ if (USER_DATA_CAUSES_REPLACEMENT) {
const ENGINES = [
DATAFUSION_DISTRIBUTED_ENGINE,
BALLISTA_ENGINE,
TRINO_ENGINE
TRINO_ENGINE,
SPARK_ENGINE
]

export const ROOT = path.join(__dirname, '../../..')