
Commit efca991

fix: Making shuffle files generated in native shuffle mode reclaimable (#1568)
* Making shuffle files generated in native shuffle mode reclaimable
* Add a unit test
* Use eventually in unit test
* Address review comments
1 parent c3f6714 commit efca991
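
The fix routes the native shuffle writer's output through Spark's IndexShuffleBlockResolver: the native code writes to temporary `.data.tmp`/`.index.tmp` files, and the writer commits them with `writeMetadataFileAndCommit`, so the final shuffle files are registered with Spark and reclaimed when the shuffle is cleaned up. The commit message mentions a unit test built on ScalaTest's `eventually`; the sketch below shows the general shape such a test can take. It is not the PR's actual test: `runShuffleJob` and `sparkLocalDir` are hypothetical placeholders standing in for running a query through Comet's native shuffle and locating the executor's local directory.

// A minimal sketch, assuming ScalaTest is on the classpath, of an
// eventually-based check that native shuffle files are reclaimed.
// `runShuffleJob()` and `sparkLocalDir` are hypothetical placeholders.
import java.io.File

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.{Seconds, Span}

def shuffleFilesIn(dir: File): Seq[File] = {
  val entries = Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty)
  entries.filter(f => f.getName.endsWith(".data") || f.getName.endsWith(".index")) ++
    entries.filter(_.isDirectory).flatMap(shuffleFilesIn)
}

runShuffleJob() // runs a query that goes through CometNativeShuffleWriter
assert(shuffleFilesIn(sparkLocalDir).nonEmpty)

System.gc() // nudge the ContextCleaner once the shuffle is unreferenced
eventually(timeout(Span(30, Seconds))) {
  assert(shuffleFilesIn(sparkLocalDir).isEmpty)
}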

File tree

5 files changed: +321 / -224 lines changed
spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet.execution.shuffle

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters.asJavaIterableConverter

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriter}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.comet.{CometExec, CometMetricNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometConf
import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
import org.apache.comet.serde.QueryPlanSerde.serializeDataType

/**
 * A [[ShuffleWriter]] that delegates the shuffle write to the native shuffle implementation.
 */
class CometNativeShuffleWriter[K, V](
    outputPartitioning: Partitioning,
    outputAttributes: Seq[Attribute],
    metrics: Map[String, SQLMetric],
    numParts: Int,
    shuffleId: Int,
    mapId: Long,
    context: TaskContext,
    metricsReporter: ShuffleWriteMetricsReporter)
    extends ShuffleWriter[K, V] {

  private val OFFSET_LENGTH = 8

  var partitionLengths: Array[Long] = _
  var mapStatus: MapStatus = _

  override def write(inputs: Iterator[Product2[K, V]]): Unit = {
    val shuffleBlockResolver =
      SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
    val dataFile = shuffleBlockResolver.getDataFile(shuffleId, mapId)
    val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId)
    val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp")
    val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp")
    val tempDataFilePath = Paths.get(tempDataFilename)
    val tempIndexFilePath = Paths.get(tempIndexFilename)

    // Call native shuffle write
    val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)

    val detailedMetrics = Seq(
      "elapsed_compute",
      "encode_time",
      "repart_time",
      "mempool_time",
      "input_batches",
      "spill_count",
      "spilled_bytes")

    // Maps native metrics to SQL metrics
    val nativeSQLMetrics = Map(
      "output_rows" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
      "data_size" -> metrics("dataSize"),
      "write_time" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
      metrics.filterKeys(detailedMetrics.contains)
    val nativeMetrics = CometMetricNode(nativeSQLMetrics)

    // Getting rid of the fake partitionId
    val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)

    val cometIter = CometExec.getCometIterator(
      Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
      outputAttributes.length,
      nativePlan,
      nativeMetrics,
      numParts,
      context.partitionId())

    // Drain the iterator so that the native plan actually writes the
    // temporary data and index files
    while (cometIter.hasNext) {
      cometIter.next()
    }
    cometIter.close()

    // Get the partition lengths from the shuffle write output index file
    var offset = 0L
    partitionLengths = Files
      .readAllBytes(tempIndexFilePath)
      .grouped(OFFSET_LENGTH)
      .drop(1) // first partition offset is always 0
      .map(indexBytes => {
        val partitionOffset =
          ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
        val partitionLength = partitionOffset - offset
        offset = partitionOffset
        partitionLength
      })
      .toArray
    Files.delete(tempIndexFilePath)

    // Total bytes written by the native code
    metricsReporter.incBytesWritten(Files.size(tempDataFilePath))

    // Commit: this writes the index file and moves the temporary data file
    // to its final, Spark-managed location, so the shuffle files are tracked
    // by Spark and can be reclaimed like those of the built-in shuffle
    shuffleBlockResolver.writeMetadataFileAndCommit(
      shuffleId,
      mapId,
      partitionLengths,
      Array.empty, // TODO: add checksums
      tempDataFilePath.toFile)
    mapStatus =
      MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
  }

  private def getNativePlan(dataFile: String, indexFile: String): Operator = {
    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
    val opBuilder = OperatorOuterClass.Operator.newBuilder()

    val scanTypes = outputAttributes.flatten { attr =>
      serializeDataType(attr.dataType)
    }

    if (scanTypes.length == outputAttributes.length) {
      scanBuilder.addAllFields(scanTypes.asJava)

      val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
      shuffleWriterBuilder.setOutputDataFile(dataFile)
      shuffleWriterBuilder.setOutputIndexFile(indexFile)
      shuffleWriterBuilder.setEnableFastEncoding(
        CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())

      if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
        val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
          case "zstd" => CompressionCodec.Zstd
          case "lz4" => CompressionCodec.Lz4
          case "snappy" => CompressionCodec.Snappy
          case other => throw new UnsupportedOperationException(s"invalid codec: $other")
        }
        shuffleWriterBuilder.setCodec(codec)
      } else {
        shuffleWriterBuilder.setCodec(CompressionCodec.None)
      }
      shuffleWriterBuilder.setCompressionLevel(
        CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)

      outputPartitioning match {
        case _: HashPartitioning =>
          val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning]

          val partitioning = PartitioningOuterClass.HashRepartition.newBuilder()
          partitioning.setNumPartitions(outputPartitioning.numPartitions)

          val partitionExprs = hashPartitioning.expressions
            .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))

          if (partitionExprs.length != hashPartitioning.expressions.length) {
            throw new UnsupportedOperationException(
              s"Partitioning $hashPartitioning is not supported.")
          }

          partitioning.addAllHashExpression(partitionExprs.asJava)

          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
          shuffleWriterBuilder.setPartitioning(
            partitioningBuilder.setHashPartition(partitioning).build())

        case SinglePartition =>
          val partitioning = PartitioningOuterClass.SinglePartition.newBuilder()

          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
          shuffleWriterBuilder.setPartitioning(
            partitioningBuilder.setSinglePartition(partitioning).build())

        case _ =>
          throw new UnsupportedOperationException(
            s"Partitioning $outputPartitioning is not supported.")
      }

      val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
      shuffleWriterOpBuilder
        .setShuffleWriter(shuffleWriterBuilder)
        .addChildren(opBuilder.setScan(scanBuilder).build())
        .build()
    } else {
      // There are unsupported scan types
      throw new UnsupportedOperationException(
        s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.")
    }
  }

  override def stop(success: Boolean): Option[MapStatus] = {
    if (success) {
      Some(mapStatus)
    } else {
      None
    }
  }

  override def getPartitionLengths(): Array[Long] = partitionLengths
}
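
For context on the index parsing in `write` above: the index file holds `numPartitions + 1` little-endian 8-byte offsets into the data file, starting at 0, and each partition's length is the difference between consecutive offsets. That is why the code drops the first entry and diffs the rest. A small self-contained illustration of that arithmetic, using made-up offsets:

// Hypothetical index contents for a 3-partition map output:
// offsets 0, 100, 250, 250 mean partition 2 is empty.
val offsets = Seq(0L, 100L, 250L, 250L)
val lengths = offsets.sliding(2).map { case Seq(a, b) => b - a }.toSeq
assert(lengths == Seq(100L, 150L, 0L))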

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala

Lines changed: 7 additions & 1 deletion
@@ -25,6 +25,8 @@ import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType

@@ -41,7 +43,11 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
     val shuffleType: ShuffleType = CometNativeShuffle,
     val schema: Option[StructType] = None,
-    val decodeTime: SQLMetric)
+    val decodeTime: SQLMetric,
+    val outputPartitioning: Option[Partitioning] = None,
+    val outputAttributes: Seq[Attribute] = Seq.empty,
+    val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty,
+    val numParts: Int = 0)
     extends ShuffleDependency[K, V, C](
       _rdd,
       partitioner,