Skip to content

Commit cefa493

Browse files
authored
Picking changes from the fix for SPARK-32709 (#39)
* Picking changes from the fix for SPARK-32709: applied patch from PR apache#33432 * Fixing compile issues * Fix compile failure in the Spark Hive module
1 parent b4400c7 commit cefa493

File tree

11 files changed

+230
-52
lines changed

11 files changed

+230
-52
lines changed

core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,33 @@ abstract class FileCommitProtocol extends Logging {
9292
*/
9393
def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
9494

95+
/**
96+
* Notifies the commit protocol to add a new file, and gets back the full path that should be
97+
* used. Must be called on the executors when running tasks.
98+
*
99+
* Note that the returned temp file may have an arbitrary path. The commit protocol only
100+
* promises that the file will be at the location specified by the arguments after job commit.
101+
*
102+
* The "dir" parameter specifies the sub-directory within the base path, used to specify
103+
* partitioning. The "spec" parameter specifies the file name. The rest are left to the commit
104+
* protocol implementation to decide.
105+
*
106+
* Important: it is the caller's responsibility to add uniquely identifying content to "spec"
107+
* if a task is going to write out multiple files to the same dir. The file commit protocol only
108+
* guarantees that files written by different tasks will not conflict.
109+
*
110+
* @since 3.2.0
111+
*/
112+
def newTaskTempFile(
113+
taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
114+
if (spec.prefix.isEmpty) {
115+
newTaskTempFile(taskContext, dir, spec.suffix)
116+
} else {
117+
throw new UnsupportedOperationException(s"${getClass.getSimpleName}.newTaskTempFile does " +
118+
s"not support file name prefix: ${spec.prefix}")
119+
}
120+
}
121+
95122
/**
96123
* Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
97124
* Depending on the implementation, there may be weaker guarantees around adding files this way.
@@ -103,6 +130,31 @@ abstract class FileCommitProtocol extends Logging {
103130
def newTaskTempFileAbsPath(
104131
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
105132

133+
/**
134+
* Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
135+
* Depending on the implementation, there may be weaker guarantees around adding files this way.
136+
*
137+
* The "absoluteDir" parameter specifies the final absolute directory of file. The "spec"
138+
* parameter specifies the file name. The rest are left to the commit protocol implementation to
139+
* decide.
140+
*
141+
* Important: it is the caller's responsibility to add uniquely identifying content to "spec"
142+
* if a task is going to write out multiple files to the same dir. The file commit protocol only
143+
* guarantees that files written by different tasks will not conflict.
144+
*
145+
* @since 3.2.0
146+
*/
147+
def newTaskTempFileAbsPath(
148+
taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
149+
if (spec.prefix.isEmpty) {
150+
newTaskTempFileAbsPath(taskContext, absoluteDir, spec.suffix)
151+
} else {
152+
throw new UnsupportedOperationException(
153+
s"${getClass.getSimpleName}.newTaskTempFileAbsPath does not support file name prefix: " +
154+
s"${spec.prefix}")
155+
}
156+
}
157+
106158
/**
107159
* Commits a task after the writes succeed. Must be called on the executors when running tasks.
108160
*/
@@ -140,6 +192,15 @@ object FileCommitProtocol extends Logging {
140192

141193
object EmptyTaskCommitMessage extends TaskCommitMessage(null)
142194

195+
/**
196+
* The specification for Spark output file name.
197+
* This is used by [[FileCommitProtocol]] to create full path of file.
198+
*
199+
* @param prefix Prefix of file.
200+
* @param suffix Suffix of file.
201+
*/
202+
final case class FileNameSpec(prefix: String, suffix: String)
203+
143204
/**
144205
* Instantiates a FileCommitProtocol using the given className.
145206
*/

core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ class HadoopMapReduceCommitProtocol(
118118

119119
override def newTaskTempFile(
120120
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
121-
val filename = getFilename(taskContext, ext)
121+
newTaskTempFile(taskContext, dir, FileNameSpec("", ext))
122+
}
123+
124+
override def newTaskTempFile(
125+
taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
126+
val filename = getFilename(taskContext, spec)
122127

123128
val stagingDir: Path = committer match {
124129
// For FileOutputCommitter it has its own staging path called "work path".
@@ -141,7 +146,12 @@ class HadoopMapReduceCommitProtocol(
141146

142147
override def newTaskTempFileAbsPath(
143148
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
144-
val filename = getFilename(taskContext, ext)
149+
newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext))
150+
}
151+
152+
override def newTaskTempFileAbsPath(
153+
taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
154+
val filename = getFilename(taskContext, spec)
145155
val absOutputPath = new Path(absoluteDir, filename).toString
146156

147157
// Include a UUID here to prevent file collisions for one task writing to different dirs.
@@ -152,12 +162,12 @@ class HadoopMapReduceCommitProtocol(
152162
tmpOutputPath
153163
}
154164

155-
protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
165+
protected def getFilename(taskContext: TaskAttemptContext, spec: FileNameSpec): String = {
156166
// The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
157167
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
158168
// the file name is fine and won't overflow.
159169
val split = taskContext.getTaskAttemptID.getTaskID.getId
160-
f"part-$split%05d-$jobId$ext"
170+
f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
161171
}
162172

163173
override def setupJob(jobContext: JobContext): Unit = {

hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
2121
import org.apache.hadoop.mapreduce.TaskAttemptContext
2222
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory}
2323

24+
import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec
2425
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
2526

2627
/**
@@ -122,20 +123,20 @@ class PathOutputCommitProtocol(
122123
*
123124
* @param taskContext task context
124125
* @param dir optional subdirectory
125-
* @param ext file extension
126+
* @param spec file naming specification
126127
* @return a path as a string
127128
*/
128129
override def newTaskTempFile(
129130
taskContext: TaskAttemptContext,
130131
dir: Option[String],
131-
ext: String): String = {
132+
spec: FileNameSpec): String = {
132133

133134
val workDir = committer.getWorkPath
134135
val parent = dir.map {
135136
d => new Path(workDir, d)
136137
}.getOrElse(workDir)
137-
val file = new Path(parent, getFilename(taskContext, ext))
138-
logTrace(s"Creating task file $file for dir $dir and ext $ext")
138+
val file = new Path(parent, getFilename(taskContext, spec))
139+
logTrace(s"Creating task file $file for dir $dir and spec $spec")
139140
file.toString
140141
}
141142

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
2222
import org.apache.hadoop.mapreduce.TaskAttemptContext
2323

2424
import org.apache.spark.internal.io.FileCommitProtocol
25+
import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec
2526
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
2627
import org.apache.spark.sql.catalyst.InternalRow
2728
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -157,7 +158,7 @@ class DynamicPartitionDataWriter(
157158
private val isPartitioned = description.partitionColumns.nonEmpty
158159

159160
/** Flag saying whether or not the data to be written out is bucketed. */
160-
private val isBucketed = description.bucketIdExpression.isDefined
161+
protected val isBucketed = description.bucketSpec.isDefined
161162

162163
assert(isPartitioned || isBucketed,
163164
s"""DynamicPartitionWriteTask should be used for writing out data that's either
@@ -196,7 +197,8 @@ class DynamicPartitionDataWriter(
196197
/** Given an input row, returns the corresponding `bucketId` */
197198
private lazy val getBucketId: InternalRow => Int = {
198199
val proj =
199-
UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
200+
UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression),
201+
description.allColumns)
200202
row => proj(row).getInt(0)
201203
}
202204

@@ -222,17 +224,23 @@ class DynamicPartitionDataWriter(
222224

223225
val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
224226

225-
// This must be in a form that matches our bucketing format. See BucketingUtils.
226-
val ext = f"$bucketIdStr.c$fileCounter%03d" +
227+
// The prefix and suffix must be in a form that matches our bucketing format.
228+
// See BucketingUtils.
229+
val prefix = bucketId match {
230+
case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id)
231+
case _ => ""
232+
}
233+
val suffix = f"$bucketIdStr.c$fileCounter%03d" +
227234
description.outputWriterFactory.getFileExtension(taskAttemptContext)
235+
val fileNameSpec = FileNameSpec(prefix, suffix)
228236

229237
val customPath = partDir.flatMap { dir =>
230238
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
231239
}
232240
val currentPath = if (customPath.isDefined) {
233-
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
241+
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec)
234242
} else {
235-
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
243+
committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec)
236244
}
237245

238246
currentWriter = description.outputWriterFactory.newInstance(
@@ -277,6 +285,16 @@ class DynamicPartitionDataWriter(
277285
}
278286
}
279287

288+
/**
289+
* Bucketing specification for all the write tasks.
290+
*
291+
* @param bucketIdExpression Expression to calculate bucket id based on bucket column(s).
292+
* @param bucketFileNamePrefix Prefix of output file name based on bucket id.
293+
*/
294+
case class WriterBucketSpec(
295+
bucketIdExpression: Expression,
296+
bucketFileNamePrefix: Int => String)
297+
280298
/** A shared job description for all the write tasks. */
281299
class WriteJobDescription(
282300
val uuid: String, // prevent collision between different (appending) write jobs
@@ -285,7 +303,7 @@ class WriteJobDescription(
285303
val allColumns: Seq[Attribute],
286304
val dataColumns: Seq[Attribute],
287305
val partitionColumns: Seq[Attribute],
288-
val bucketIdExpression: Option[Expression],
306+
val bucketSpec: Option[WriterBucketSpec],
289307
val path: String,
290308
val customPartitionLocations: Map[TablePartitionSpec, String],
291309
val maxRecordsPerFile: Long,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
3939
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
4040
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
4141
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution}
42+
import org.apache.spark.sql.execution.command.DDLUtils
4243
import org.apache.spark.sql.internal.SQLConf
4344
import org.apache.spark.sql.types.StringType
4445
import org.apache.spark.unsafe.types.UTF8String
@@ -113,12 +114,33 @@ object FileFormatWriter extends Logging {
113114
}
114115
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan
115116

116-
val bucketIdExpression = bucketSpec.map { spec =>
117+
val writerBucketSpec = bucketSpec.map { spec =>
117118
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
118-
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
119-
// guarantee the data distribution is same between shuffle and bucketed data source, which
120-
// enables us to only shuffle one side when join a bucketed table and a normal one.
121-
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
119+
120+
if (options.getOrElse(DDLUtils.HIVE_PROVIDER, "false") == "true") {
121+
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
122+
// Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
123+
// columns is negative. See Hive implementation in
124+
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
125+
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
126+
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))
127+
128+
// The bucket file name prefix follows the Hive, Presto and Trino convention, so this
129+
// makes sure a Hive bucketed table written by Spark can be read by other SQL engines.
130+
//
131+
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
132+
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
133+
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
134+
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
135+
} else {
136+
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
137+
// expression, so that we can guarantee the data distribution is same between shuffle and
138+
// bucketed data source, which enables us to only shuffle one side when join a bucketed
139+
// table and a normal one.
140+
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
141+
.partitionIdExpression
142+
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
143+
}
122144
}
123145
val sortColumns = bucketSpec.toSeq.flatMap {
124146
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -139,7 +161,7 @@ object FileFormatWriter extends Logging {
139161
allColumns = outputSpec.outputColumns,
140162
dataColumns = dataColumns,
141163
partitionColumns = partitionColumns,
142-
bucketIdExpression = bucketIdExpression,
164+
bucketSpec = writerBucketSpec,
143165
path = outputSpec.outputPath,
144166
customPartitionLocations = outputSpec.customPartitionLocations,
145167
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
@@ -150,7 +172,8 @@ object FileFormatWriter extends Logging {
150172
)
151173

152174
// We should first sort by partition columns, then bucket id, and finally sorting columns.
153-
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
175+
val requiredOrdering =
176+
partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
154177
// the sort order doesn't matter
155178
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
156179
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -265,7 +288,7 @@ object FileFormatWriter extends Logging {
265288
if (sparkPartitionId != 0 && !iterator.hasNext) {
266289
// In case of empty job, leave first partition to save meta for file format like parquet.
267290
new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
268-
} else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
291+
} else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
269292
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
270293
} else {
271294
new DynamicPartitionDataWriter(description, taskAttemptContext, committer)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ abstract class FileWriteBuilder(
129129
allColumns = allColumns,
130130
dataColumns = allColumns,
131131
partitionColumns = Seq.empty,
132-
bucketIdExpression = None,
132+
bucketSpec = None,
133133
path = pathName,
134134
customPartitionLocations = Map.empty,
135135
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)

0 commit comments

Comments
 (0)