Skip to content

Commit 3fd1be2

Browse files
committed
feat: measure on-disk shuffle sizes and add short-strings data generator
- Add --schema short-strings option to generate_data.py that produces 7 short random UUID string columns + 1 timestamp, matching the schema from issue #3882
- Update shuffle_size.py to measure actual shuffle .data file sizes on disk via spark.local.dir, in addition to the REST API metric
- Update run_shuffle_size_benchmark.sh with dedicated local dirs per run, driver memory, and shuffle enable config
1 parent 296f338 commit 3fd1be2

File tree

3 files changed

+124
-18
lines changed

3 files changed

+124
-18
lines changed

benchmarks/pyspark/benchmarks/shuffle_size.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
"""
2020
Shuffle size benchmark for measuring shuffle write bytes.
2121
22-
Measures the actual shuffle write bytes reported by Spark to compare
22+
Measures the actual shuffle file sizes on disk to compare
2323
shuffle file sizes between Spark and Comet shuffle implementations.
2424
This is useful for investigating shuffle format overhead (see issue #3882).
25+
26+
The benchmark sets spark.local.dir to a dedicated temp directory and
27+
measures the total size of shuffle data files (.data) written there.
2528
"""
2629

2730
import json
31+
import os
2832
import urllib.request
2933
from typing import Dict, Any
3034

@@ -43,6 +47,16 @@ def get_shuffle_write_bytes(spark) -> int:
4347
return sum(s.get("shuffleWriteBytes", 0) for s in stages)
4448

4549

50+
def get_shuffle_disk_bytes(local_dir: str) -> int:
51+
"""Walk spark.local.dir and sum the sizes of all shuffle .data files."""
52+
total = 0
53+
for root, _dirs, files in os.walk(local_dir):
54+
for f in files:
55+
if f.endswith(".data"):
56+
total += os.path.getsize(os.path.join(root, f))
57+
return total
58+
59+
4660
def format_bytes(b: int) -> str:
4761
"""Format byte count as human-readable string."""
4862
if b >= 1024 ** 3:
@@ -55,11 +69,17 @@ def format_bytes(b: int) -> str:
5569

5670
class ShuffleSizeBenchmark(Benchmark):
5771
"""
58-
Benchmark that measures shuffle write bytes via the Spark REST API.
72+
Benchmark that measures shuffle write bytes on disk.
73+
74+
Runs a simple scan -> repartition -> count pipeline and reports
75+
the actual shuffle data file sizes alongside the Spark REST API
76+
metric. Useful for comparing shuffle format overhead between
77+
Spark and Comet.
5978
60-
Runs a simple scan -> repartition -> write pipeline and reports
61-
the shuffle write size alongside wall-clock time. Useful for
62-
comparing shuffle format overhead between Spark and Comet.
79+
NOTE: The Spark session must be configured with spark.local.dir
80+
pointing to a dedicated empty directory so that we can measure
81+
shuffle file sizes accurately. The run_shuffle_size_benchmark.sh
82+
script handles this automatically.
6383
"""
6484

6585
def __init__(self, spark, data_path: str, mode: str,
@@ -73,7 +93,7 @@ def name(cls) -> str:
7393

7494
@classmethod
7595
def description(cls) -> str:
76-
return "Measure shuffle write bytes (scan -> repartition -> write)"
96+
return "Measure shuffle write bytes (scan -> repartition -> count)"
7797

7898
def run(self) -> Dict[str, Any]:
7999
df = self.spark.read.parquet(self.data_path)
@@ -85,6 +105,11 @@ def run(self) -> Dict[str, Any]:
85105
)
86106
print(f"Schema: {schema_desc}")
87107

108+
# Read spark.local.dir so we can measure shuffle files on disk
109+
local_dir = self.spark.sparkContext.getConf().get(
110+
"spark.local.dir", "/tmp"
111+
)
112+
88113
output_path = (
89114
f"/tmp/shuffle-size-benchmark-output-{self.mode}"
90115
)
@@ -96,23 +121,32 @@ def benchmark_operation():
96121

97122
duration_ms = self._time_operation(benchmark_operation)
98123

99-
shuffle_write_bytes = 0
124+
# Measure actual shuffle file sizes on disk.
125+
# Shuffle .data files persist until SparkContext shutdown,
126+
# so they are still available after the job completes.
127+
disk_bytes = get_shuffle_disk_bytes(local_dir)
128+
129+
# Also grab the REST API metric for comparison
130+
api_bytes = 0
100131
try:
101-
shuffle_write_bytes = get_shuffle_write_bytes(self.spark)
132+
api_bytes = get_shuffle_write_bytes(self.spark)
102133
except Exception as e:
103-
print(f"Warning: could not read shuffle metrics: {e}")
134+
print(f"Warning: could not read shuffle metrics from REST API: {e}")
104135

105-
bytes_per_record = (
106-
shuffle_write_bytes / row_count if row_count > 0 else 0
107-
)
136+
disk_bpr = disk_bytes / row_count if row_count > 0 else 0
137+
api_bpr = api_bytes / row_count if row_count > 0 else 0
108138

109-
print(f"Shuffle write: {format_bytes(shuffle_write_bytes)}")
110-
print(f"Bytes/record: {bytes_per_record:.1f}")
139+
print(f"Shuffle disk: {format_bytes(disk_bytes)} "
140+
f"({disk_bpr:.1f} B/record)")
141+
print(f"Shuffle API metric: {format_bytes(api_bytes)} "
142+
f"({api_bpr:.1f} B/record)")
111143

112144
return {
113145
"duration_ms": duration_ms,
114146
"row_count": row_count,
115147
"num_partitions": self.num_partitions,
116-
"shuffle_write_bytes": shuffle_write_bytes,
117-
"bytes_per_record": round(bytes_per_record, 1),
148+
"shuffle_disk_bytes": disk_bytes,
149+
"shuffle_disk_bytes_per_record": round(disk_bpr, 1),
150+
"shuffle_api_bytes": api_bytes,
151+
"shuffle_api_bytes_per_record": round(api_bpr, 1),
118152
}

benchmarks/pyspark/generate_data.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,53 @@ def generate_data(output_path: str, num_rows: int, num_partitions: int):
412412
spark.stop()
413413

414414

415+
def generate_short_strings_data(output_path: str, num_rows: int,
                                num_partitions: int):
    """Generate data matching the schema from issue #3882.

    Reproduces the problematic scenario: 7 short unique string columns + 1
    timestamp column. The original reporter saw 3x shuffle overhead with
    204M records of this shape (25.1 B/record in Comet vs 8.3 B/record in
    Spark).
    """

    spark = SparkSession.builder \
        .appName("ShuffleBenchmark-DataGen-ShortStrings") \
        .getOrCreate()

    print(f"Generating {num_rows:,} rows with {num_partitions} partitions")
    print(f"Output path: {output_path}")
    print("Schema: 7 short unique string columns + 1 timestamp (issue #3882)")

    base = spark.range(0, num_rows, numPartitions=num_partitions)

    # Build the 7 short string column expressions programmatically. uuid()
    # yields truly random strings that defeat compression, exposing Arrow
    # IPC per-batch overhead (mimics the reporter's schema).
    string_exprs = [
        f"substring(uuid(), 1, 8) as str_col_{i}" for i in range(1, 8)
    ]
    df = base.selectExpr(
        *string_exprs,
        # Timestamp column
        "timestamp_seconds(1600000000 + id) as ts_col",
    )

    print(f"Generated schema with {len(df.columns)} columns")
    df.printSchema()

    df.write.mode("overwrite").parquet(output_path)

    # Read back what was written as a sanity check on the row count.
    written_df = spark.read.parquet(output_path)
    actual_count = written_df.count()
    print(f"Wrote {actual_count:,} rows to {output_path}")

    spark.stop()
460+
461+
415462
def main():
416463
parser = argparse.ArgumentParser(
417464
description="Generate test data for shuffle benchmark"
@@ -433,13 +480,24 @@ def main():
433480
default=None,
434481
help="Number of output partitions (default: auto based on cluster)"
435482
)
483+
parser.add_argument(
484+
"--schema", "-s",
485+
choices=["wide", "short-strings"],
486+
default="wide",
487+
help="Schema to generate: 'wide' (100 columns with nested types) "
488+
"or 'short-strings' (7 short unique strings + 1 timestamp, "
489+
"matches issue #3882)"
490+
)
436491

437492
args = parser.parse_args()
438493

439494
# Default partitions to a reasonable number if not specified
440495
num_partitions = args.partitions if args.partitions else 200
441496

442-
generate_data(args.output, args.rows, num_partitions)
497+
if args.schema == "short-strings":
498+
generate_short_strings_data(args.output, args.rows, num_partitions)
499+
else:
500+
generate_data(args.output, args.rows, num_partitions)
443501

444502

445503
if __name__ == "__main__":

benchmarks/pyspark/run_shuffle_size_benchmark.sh

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,23 @@ echo "Executor memory: $EXECUTOR_MEMORY"
6868
echo "Off-heap size: $OFFHEAP_SIZE"
6969
echo "========================================"
7070

71+
# Use dedicated local dirs so we can measure actual shuffle file sizes on disk
72+
SPARK_LOCAL_DIR=$(mktemp -d /tmp/spark-shuffle-bench-spark-XXXXXX)
73+
COMET_LOCAL_DIR=$(mktemp -d /tmp/spark-shuffle-bench-comet-XXXXXX)
74+
75+
cleanup() {
76+
rm -rf "$SPARK_LOCAL_DIR" "$COMET_LOCAL_DIR"
77+
}
78+
trap cleanup EXIT
79+
7180
# Run Spark baseline (no Comet)
7281
echo ""
7382
echo ">>> Running SPARK (no Comet) shuffle size benchmark..."
7483
$SPARK_HOME/bin/spark-submit \
7584
--master "$SPARK_MASTER" \
85+
--driver-memory "$EXECUTOR_MEMORY" \
7686
--executor-memory "$EXECUTOR_MEMORY" \
87+
--conf spark.local.dir="$SPARK_LOCAL_DIR" \
7788
--conf spark.comet.enabled=false \
7889
"$SCRIPT_DIR/run_benchmark.py" \
7990
--data "$DATA_PATH" \
@@ -85,16 +96,19 @@ echo ""
8596
echo ">>> Running COMET NATIVE shuffle size benchmark..."
8697
$SPARK_HOME/bin/spark-submit \
8798
--master "$SPARK_MASTER" \
99+
--driver-memory "$EXECUTOR_MEMORY" \
88100
--executor-memory "$EXECUTOR_MEMORY" \
89101
--jars "$COMET_JAR" \
90102
--driver-class-path "$COMET_JAR" \
91103
--conf spark.executor.extraClassPath="$COMET_JAR" \
104+
--conf spark.local.dir="$COMET_LOCAL_DIR" \
92105
--conf spark.plugins=org.apache.spark.CometPlugin \
93106
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
94107
--conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
95108
--conf spark.memory.offHeap.enabled=true \
96109
--conf spark.memory.offHeap.size="$OFFHEAP_SIZE" \
97110
--conf spark.comet.enabled=true \
111+
--conf spark.comet.exec.shuffle.enabled=true \
98112
--conf spark.comet.exec.shuffle.mode=native \
99113
--conf spark.comet.explainFallback.enabled=true \
100114
"$SCRIPT_DIR/run_benchmark.py" \
@@ -106,4 +120,4 @@ echo ""
106120
echo "========================================"
107121
echo "BENCHMARK COMPLETE"
108122
echo "========================================"
109-
echo "Compare 'Shuffle write' and 'Bytes/record' between the two runs above."
123+
echo "Compare 'Shuffle disk' bytes/record between the two runs above."

0 commit comments

Comments (0)