/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet.execution.shuffle

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters.asJavaIterableConverter

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriter}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.comet.{CometExec, CometMetricNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometConf
import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
import org.apache.comet.serde.QueryPlanSerde.serializeDataType

/**
 * A [[ShuffleWriter]] that delegates the actual shuffle write to the native shuffle
 * implementation.
 */
class CometNativeShuffleWriter[K, V](
    outputPartitioning: Partitioning,
    outputAttributes: Seq[Attribute],
    metrics: Map[String, SQLMetric],
    numParts: Int,
    shuffleId: Int,
    mapId: Long,
    context: TaskContext,
    metricsReporter: ShuffleWriteMetricsReporter)
    extends ShuffleWriter[K, V] {

  // Each offset stored in the shuffle index file is an 8-byte little-endian long
  private val OFFSET_LENGTH = 8

  var partitionLengths: Array[Long] = _
  var mapStatus: MapStatus = _

  override def write(inputs: Iterator[Product2[K, V]]): Unit = {
    val shuffleBlockResolver =
      SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
    val dataFile = shuffleBlockResolver.getDataFile(shuffleId, mapId)
    val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId)
    val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp")
    val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp")
    val tempDataFilePath = Paths.get(tempDataFilename)
    val tempIndexFilePath = Paths.get(tempIndexFilename)

    // Build the native plan that performs the shuffle write into the temporary files
    val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)

    val detailedMetrics = Seq(
      "elapsed_compute",
      "encode_time",
      "repart_time",
      "mempool_time",
      "input_batches",
      "spill_count",
      "spilled_bytes")

    // Maps native metrics to SQL metrics
    val nativeSQLMetrics = Map(
      "output_rows" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
      "data_size" -> metrics("dataSize"),
      "write_time" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
      metrics.filterKeys(detailedMetrics.contains)
    val nativeMetrics = CometMetricNode(nativeSQLMetrics)

    // Drop the fake partition id key; only the ColumnarBatch values are fed to the native plan
    val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)

    val cometIter = CometExec.getCometIterator(
      Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
      outputAttributes.length,
      nativePlan,
      nativeMetrics,
      numParts,
      context.partitionId())

    // Drain the iterator so the native plan actually executes the shuffle write
    while (cometIter.hasNext) {
      cometIter.next()
    }
    cometIter.close()

    // Recover per-partition lengths from the index file written by the native code. The index
    // file is a sequence of little-endian int64 offsets into the data file, starting with 0,
    // so each length is the difference between consecutive offsets.
    var offset = 0L
    partitionLengths = Files
      .readAllBytes(tempIndexFilePath)
      .grouped(OFFSET_LENGTH)
      .drop(1) // first partition offset is always 0
      .map(indexBytes => {
        val partitionOffset =
          ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
        val partitionLength = partitionOffset - offset
        offset = partitionOffset
        partitionLength
      })
      .toArray
    Files.delete(tempIndexFilePath)

    // Report the total number of bytes written by the native shuffle writer
    metricsReporter.incBytesWritten(Files.size(tempDataFilePath))

    // Commit the shuffle output and build the map status for the driver
    shuffleBlockResolver.writeMetadataFileAndCommit(
      shuffleId,
      mapId,
      partitionLengths,
      Array.empty, // TODO: add checksums
      tempDataFilePath.toFile)
    mapStatus =
      MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
  }

  private def getNativePlan(dataFile: String, indexFile: String): Operator = {
    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
    val opBuilder = OperatorOuterClass.Operator.newBuilder()

    // serializeDataType returns None for types that cannot be serialized, so a length mismatch
    // below means at least one output attribute has an unsupported data type
    val scanTypes = outputAttributes.flatMap { attr =>
      serializeDataType(attr.dataType)
    }

    if (scanTypes.length == outputAttributes.length) {
      scanBuilder.addAllFields(scanTypes.asJava)

      val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
      shuffleWriterBuilder.setOutputDataFile(dataFile)
      shuffleWriterBuilder.setOutputIndexFile(indexFile)
      shuffleWriterBuilder.setEnableFastEncoding(
        CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())

      if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
        val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
          case "zstd" => CompressionCodec.Zstd
          case "lz4" => CompressionCodec.Lz4
          case "snappy" => CompressionCodec.Snappy
          case other => throw new UnsupportedOperationException(s"invalid codec: $other")
        }
        shuffleWriterBuilder.setCodec(codec)
      } else {
        shuffleWriterBuilder.setCodec(CompressionCodec.None)
      }
      shuffleWriterBuilder.setCompressionLevel(
        CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)

      outputPartitioning match {
        case hashPartitioning: HashPartitioning =>
          val partitioning = PartitioningOuterClass.HashRepartition.newBuilder()
          partitioning.setNumPartitions(outputPartitioning.numPartitions)

          val partitionExprs = hashPartitioning.expressions
            .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))

          if (partitionExprs.length != hashPartitioning.expressions.length) {
            throw new UnsupportedOperationException(
              s"Partitioning $hashPartitioning is not supported.")
          }

          partitioning.addAllHashExpression(partitionExprs.asJava)

          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
          shuffleWriterBuilder.setPartitioning(
            partitioningBuilder.setHashPartition(partitioning).build())

        case SinglePartition =>
          val partitioning = PartitioningOuterClass.SinglePartition.newBuilder()

          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
          shuffleWriterBuilder.setPartitioning(
            partitioningBuilder.setSinglePartition(partitioning).build())

        case _ =>
          throw new UnsupportedOperationException(
            s"Partitioning $outputPartitioning is not supported.")
      }

      val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
      shuffleWriterOpBuilder
        .setShuffleWriter(shuffleWriterBuilder)
        .addChildren(opBuilder.setScan(scanBuilder).build())
        .build()
    } else {
      // At least one output attribute has a data type that cannot be serialized
      throw new UnsupportedOperationException(
        s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.")
    }
  }

  override def stop(success: Boolean): Option[MapStatus] = {
    if (success) {
      Some(mapStatus)
    } else {
      None
    }
  }

  override def getPartitionLengths(): Array[Long] = partitionLengths
}